diff --git a/docs/source/en/api/pipelines/unclip.md b/docs/source/en/api/pipelines/unclip.md
index c9a3164226..8011a4b533 100644
--- a/docs/source/en/api/pipelines/unclip.md
+++ b/docs/source/en/api/pipelines/unclip.md
@@ -7,6 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# unCLIP
[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).
diff --git a/docs/source/en/api/pipelines/unidiffuser.md b/docs/source/en/api/pipelines/unidiffuser.md
index bce55b67ed..7d767f2db5 100644
--- a/docs/source/en/api/pipelines/unidiffuser.md
+++ b/docs/source/en/api/pipelines/unidiffuser.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# UniDiffuser
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index 18b8207e3b..81cd242151 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
```py
# pip install ftfy
import torch
- from diffusers import WanPipeline, AutoModel
+ from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
- vae = AutoModel.from_single_file(
+ vae = AutoencoderKLWan.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
- transformer = AutoModel.from_single_file(
+ transformer = WanTransformer3DModel.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
torch_dtype=torch.bfloat16
)
diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md
index 561df2017d..2be3631d84 100644
--- a/docs/source/en/api/pipelines/wuerstchen.md
+++ b/docs/source/en/api/pipelines/wuerstchen.md
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
# Würstchen
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index b18977720c..5a382c1c94 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
+If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
+
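+For example, a minimal sketch (assuming a pipeline whose `transformer` you compile for hotswapping, as in the steps above; the other compile flags are up to you):
+
+```py
+# dynamic=True avoids recompilation when the inference resolution changes
+pipeline.transformer = torch.compile(
+    pipeline.transformer, mode="max-autotune", fullgraph=True, dynamic=True
+)
+```
+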
There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
## Merge
diff --git a/docs/source/en/using-diffusers/batched_inference.md b/docs/source/en/using-diffusers/batched_inference.md
new file mode 100644
index 0000000000..b5e55c27ca
--- /dev/null
+++ b/docs/source/en/using-diffusers/batched_inference.md
@@ -0,0 +1,264 @@
+
+
+# Batch inference
+
+Batch inference processes multiple prompts at a time to increase throughput. It is more efficient because processing prompts together maximizes GPU utilization, whereas processing a single prompt at a time leaves the GPU underutilized.
+
+The downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.
+
+
+
+
+For text-to-image, pass a list of prompts to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+prompts = [
+ "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+images = pipeline(
+ prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+images = pipeline(
+ prompt=prompts,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 3, figsize=(18, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+For image-to-image, pass a list of input images and prompts to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+input_images = [
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+ "cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ image=input_images,
+ guidance_scale=8.0,
+ strength=0.5
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+
+images = pipeline(
+ prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ image=input_image,
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+input_images = [
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ image=input_images,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+## Deterministic generation
+
+Enable reproducible batch generation by passing a list of [Generators](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tying each `Generator` to a seed so you can reuse it.
+
+Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch.
+
+Don't multiply the `Generator` by the batch size because that only creates one `Generator` object that is used sequentially for each image in the batch.
+
+```py
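+# don't do this: it creates a single Generator that is reused for every image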
+generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
+```
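+
+The list comprehension approach described above creates a distinct `Generator`, and therefore a distinct seed, for each image.
+
+```py
+# one unique Generator (and seed) per image in the batch
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
+```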
+
+Pass the `generator` to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
+prompts = [
+ "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ generator=generator
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+You can use this to iteratively select an image associated with a seed, and then improve on it by crafting a more detailed prompt while reusing that seed, as sketched below.
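+
+For example, a minimal sketch (assuming you liked the image generated with seed `0` above and want to explore more detailed variants of its prompt):
+
+```py
+# reuse seed 0 for every image so only the prompt changes
+generator = [torch.Generator(device="cuda").manual_seed(0) for _ in range(3)]
+
+prompts = [
+    "cinematic photo of a beautiful sunset over mountains, golden hour, 35mm photograph, film, professional, 4k, highly detailed",
+    "cinematic photo of a beautiful sunset over mountains, dramatic clouds, 35mm photograph, film, professional, 4k, highly detailed",
+    "cinematic photo of a beautiful sunset over mountains, vivid colors, 35mm photograph, film, professional, 4k, highly detailed",
+]
+
+images = pipeline(prompt=prompts, generator=generator).images
+```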
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md
index 60b8fee754..ac9350f24c 100644
--- a/docs/source/en/using-diffusers/reusing_seeds.md
+++ b/docs/source/en/using-diffusers/reusing_seeds.md
@@ -136,53 +136,3 @@ result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="
print("L_inf dist =", abs(result1 - result2).max())
"L_inf dist = tensor(0., device='cuda:0')"
```
-
-## Deterministic batch generation
-
-A practical application of creating reproducible pipelines is *deterministic batch generation*. You generate a batch of images and select one image to improve with a more detailed prompt. The main idea is to pass a list of [Generator's](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed so you can reuse it.
-
-Let's use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint and generate a batch of images.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline = pipeline.to("cuda")
-```
-
-Define four different `Generator`s and assign each `Generator` a seed (`0` to `3`). Then generate a batch of images and pick one to iterate on.
-
-> [!WARNING]
-> Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. If you multiply the `Generator` by the batch size integer, it only creates *one* `Generator` object that is used sequentially for each image in the batch.
->
-> ```py
-> [torch.Generator().manual_seed(seed)] * 4
-> ```
-
-```python
-generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
-prompt = "Labrador in the style of Vermeer"
-images = pipeline(prompt, generator=generator, num_images_per_prompt=4).images[0]
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-

-
-
-Let's improve the first image (you can choose any image you want) which corresponds to the `Generator` with seed `0`. Add some additional text to your prompt and then make sure you reuse the same `Generator` with seed `0`. All the generated images should resemble the first image.
-
-```python
-prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
-generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
-images = pipeline(prompt, generator=generator).images
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-

-
diff --git a/docs/source/en/using-diffusers/schedulers.md b/docs/source/en/using-diffusers/schedulers.md
index a3efbf2e80..aabb9dd31c 100644
--- a/docs/source/en/using-diffusers/schedulers.md
+++ b/docs/source/en/using-diffusers/schedulers.md
@@ -242,3 +242,15 @@ unet = UNet2DConditionModel.from_pretrained(
)
unet.save_pretrained("./local-unet", variant="non_ema")
```
+
+Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
+
+```py
+import torch
+from diffusers import AutoModel
+
+unet = AutoModel.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+```
+
+You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. Unlike the `torch_dtype` argument, which respects `_keep_in_fp32_modules`, it converts *all* weights. This matters for models whose layers must remain in fp32 for numerical stability and the best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
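+
+For example, a minimal sketch of the on-the-fly conversion (note that, unlike `torch_dtype`, this also casts any modules listed in `_keep_in_fp32_modules`):
+
+```py
+import torch
+from diffusers import AutoModel
+
+unet = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
+# .to() converts *all* weights to fp16, including layers meant to stay in fp32
+unet = unet.to(torch.float16)
+```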
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 3c8b75a088..53ee0f89e2 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -837,11 +837,6 @@ def main(args):
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
- if args.train_norm_layers:
- for name, param in flux_transformer.named_parameters():
- if any(k in name for k in NORM_LAYER_PREFIXES):
- param.requires_grad = True
-
if args.lora_layers is not None:
if args.lora_layers != "all-linear":
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
)
flux_transformer.add_adapter(transformer_lora_config)
+ if args.train_norm_layers:
+ for name, param in flux_transformer.named_parameters():
+ if any(k in name for k in NORM_LAYER_PREFIXES):
+ param.requires_grad = True
+
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py
index 0c0426a1ef..6f6563ad64 100644
--- a/scripts/convert_cosmos_to_diffusers.py
+++ b/scripts/convert_cosmos_to_diffusers.py
@@ -95,7 +95,6 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
- # "extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 885d37fc8e..77971de414 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -412,6 +412,7 @@ else:
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
+ "FluxKontextInpaintPipeline",
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
@@ -1030,6 +1031,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextInpaintPipeline,
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index c072165ded..412c057794 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -937,6 +937,27 @@ class LoraBaseMixin:
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
+ After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+ can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+ GPU before using those LoRA adapters for inference.
+
+ ```python
+ >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+ >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+ >>> pipe.set_adapters("adapter-1")
+ >>> image_1 = pipe(**kwargs)
+ >>> # switch to adapter-2, offload adapter-1
+ >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+ >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-2")
+ >>> image_2 = pipe(**kwargs)
+ >>> # switch back to adapter-1, offload adapter-2
+ >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+ >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-1")
+ >>> ...
+ ```
+
Args:
adapter_names (`List[str]`):
List of adapters to send device to.
@@ -952,6 +973,10 @@ class LoraBaseMixin:
for module in model.modules():
if isinstance(module, BaseTunerLayer):
for adapter_name in adapter_names:
+ if adapter_name not in module.lora_A:
+ # it is sufficient to check lora_A
+ continue
+
module.lora_A[adapter_name].to(device)
module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 25e06c007f..df3aa6212f 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
return converted_state_dict
+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+ converted_state_dict = {}
+ original_state_dict_keys = list(original_state_dict.keys())
+ num_layers = 19
+ num_single_layers = 38
+ inner_dim = 3072
+ mlp_ratio = 4.0
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ original_block_prefix = "base_model.model."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norms
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+ )
+
+ # Q, K, V
+ if lora_key == "lora_A":
+ sample_lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+ context_lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ else:
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+ ),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+ ),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+ if f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+ if f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+ )
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+ )
+
+ # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+ )
+
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+ if lora_key == "lora_A":
+ lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ lora_bias = original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias")
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+ else:
+ q, k, v, mlp = torch.split(
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+ split_size,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+ split_size,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+ )
+
+ for lora_key in ["lora_A", "lora_B"]:
+ converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+ )
+
+ if len(original_state_dict) > 0:
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
@@ -1603,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+ has_time_projection_weight = any(
+ k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+ )
- diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
- if diff_keys:
- for diff_k in diff_keys:
- param = original_state_dict[diff_k]
- # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
- # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
- # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
- # is okay to ignore because they do not affect the model output in a significant manner.
- threshold = 1.6e-2
- absdiff = param.abs().max() - param.abs().min()
- all_zero = torch.all(param == 0).item()
- all_absdiff_lower_than_threshold = absdiff < threshold
- if all_zero or all_absdiff_lower_than_threshold:
- logger.debug(
- f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
- )
- original_state_dict.pop(diff_k)
+ for key in list(original_state_dict.keys()):
+ if key.endswith((".diff", ".diff_b")) and "norm" in key:
+ # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+ # in future if needed and they are not zeroed.
+ original_state_dict.pop(key)
+ logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+ if "time_projection" in key and not has_time_projection_weight:
+ # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+ # our lora config adds the time proj lora layers, but we don't have the weights for them.
+ # CausVid lora has the weight keys and the bias keys.
+ original_state_dict.pop(key)
# For the `diff_b` keys, we treat them as lora_bias.
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 4fea005cbc..4ee4808d80 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -41,6 +41,7 @@ from .lora_base import ( # noqa
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
+ _convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
@@ -2062,6 +2063,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
return_metadata=return_lora_metadata,
)
+ is_fal_kontext = any("base_model" in k for k in state_dict)
+ if is_fal_kontext:
+ state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
+ return cls._prepare_outputs(
+ state_dict,
+ metadata=metadata,
+ alphas=None,
+ return_alphas=return_alphas,
+ return_metadata=return_lora_metadata,
+ )
+
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 0c6f3cda66..17ac81ca26 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -31,6 +31,7 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
+ convert_cosmos_transformer_checkpoint_to_diffusers,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
@@ -135,6 +136,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
+ "WanVACETransformer3DModel": {
+ "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
@@ -143,6 +148,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
+ "CosmosTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
}
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index d8d183304e..ee0786aa2d 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -126,7 +126,18 @@ CHECKPOINT_KEY_NAMES = {
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
+ "wan_vace": "vace_blocks.0.after_proj.bias",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
+ "cosmos-1.0": [
+ "net.x_embedder.proj.1.weight",
+ "net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
+ "net.extra_pos_embedder.pos_emb_h",
+ ],
+ "cosmos-2.0": [
+ "net.x_embedder.proj.1.weight",
+ "net.blocks.0.self_attn.q_proj.weight",
+ "net.pos_embedder.dim_spatial_range",
+ ],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -192,7 +203,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+ "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
+ "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
+ "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
+ "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
+ "cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
+ "cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
+ "cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
+ "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
+ "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
+ "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
}
# Use to configure model sample size when original config is provided
@@ -698,17 +719,44 @@ def infer_diffusers_model_type(checkpoint):
else:
target_key = "patch_embedding.weight"
- if checkpoint[target_key].shape[0] == 1536:
+ if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
+ if checkpoint[target_key].shape[0] == 1536:
+ model_type = "wan-vace-1.3B"
+ elif checkpoint[target_key].shape[0] == 5120:
+ model_type = "wan-vace-14B"
+
+ elif checkpoint[target_key].shape[0] == 1536:
model_type = "wan-t2v-1.3B"
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
model_type = "wan-t2v-14B"
else:
model_type = "wan-i2v-14B"
+
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
+
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
model_type = "hidream"
+
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
+ x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
+ if x_embedder_shape[1] == 68:
+ model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
+ elif x_embedder_shape[1] == 72:
+ model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
+ else:
+ raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
+
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
+ x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
+ if x_embedder_shape[1] == 68:
+ model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
+ elif x_embedder_shape[1] == 72:
+ model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
+ else:
+ raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
+
else:
model_type = "v1"
@@ -3093,6 +3141,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # For the VACE model
+ "before_proj": "proj_in",
+ "after_proj": "proj_out",
}
for key in list(checkpoint.keys()):
@@ -3479,3 +3530,116 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
return converted_state_dict
+
+
+def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ def remove_keys_(key: str, state_dict):
+ state_dict.pop(key)
+
+ def rename_transformer_blocks_(key: str, state_dict):
+ block_index = int(key.split(".")[1].removeprefix("block"))
+ new_key = key
+ old_prefix = f"blocks.block{block_index}"
+ new_prefix = f"transformer_blocks.{block_index}"
+ new_key = new_prefix + new_key.removeprefix(old_prefix)
+ state_dict[new_key] = state_dict.pop(key)
+
+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
+ "t_embedder.1": "time_embed.t_embedder",
+ "affline_norm": "time_embed.norm",
+ ".blocks.0.block.attn": ".attn1",
+ ".blocks.1.block.attn": ".attn2",
+ ".blocks.2.block": ".ff",
+ ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
+ ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
+ ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
+ ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
+ ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
+ ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
+ "to_q.0": "to_q",
+ "to_q.1": "norm_q",
+ "to_k.0": "to_k",
+ "to_k.1": "norm_k",
+ "to_v.0": "to_v",
+ "layer1": "net.0.proj",
+ "layer2": "net.2",
+ "proj.1": "proj",
+ "x_embedder": "patch_embed",
+ "extra_pos_embedder": "learnable_pos_embed",
+ "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+ "final_layer.adaLN_modulation.2": "norm_out.linear_2",
+ "final_layer.linear": "proj_out",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
+ "blocks.block": rename_transformer_blocks_,
+ "logvar.0.freqs": remove_keys_,
+ "logvar.0.phases": remove_keys_,
+ "logvar.1.weight": remove_keys_,
+ "pos_embedder.seq": remove_keys_,
+ }
+
+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
+ "t_embedder.1": "time_embed.t_embedder",
+ "t_embedding_norm": "time_embed.norm",
+ "blocks": "transformer_blocks",
+ "adaln_modulation_self_attn.1": "norm1.linear_1",
+ "adaln_modulation_self_attn.2": "norm1.linear_2",
+ "adaln_modulation_cross_attn.1": "norm2.linear_1",
+ "adaln_modulation_cross_attn.2": "norm2.linear_2",
+ "adaln_modulation_mlp.1": "norm3.linear_1",
+ "adaln_modulation_mlp.2": "norm3.linear_2",
+ "self_attn": "attn1",
+ "cross_attn": "attn2",
+ "q_proj": "to_q",
+ "k_proj": "to_k",
+ "v_proj": "to_v",
+ "output_proj": "to_out.0",
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ "mlp.layer1": "ff.net.0.proj",
+ "mlp.layer2": "ff.net.2",
+ "x_embedder.proj.1": "patch_embed.proj",
+ "final_layer.adaln_modulation.1": "norm_out.linear_1",
+ "final_layer.adaln_modulation.2": "norm_out.linear_2",
+ "final_layer.linear": "proj_out",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
+ "accum_video_sample_counter": remove_keys_,
+ "accum_image_sample_counter": remove_keys_,
+ "accum_iteration": remove_keys_,
+ "accum_train_in_hours": remove_keys_,
+ "pos_embedder.seq": remove_keys_,
+ "pos_embedder.dim_spatial_range": remove_keys_,
+ "pos_embedder.dim_temporal_range": remove_keys_,
+ "_extra_state": remove_keys_,
+ }
+
+ PREFIX_KEY = "net."
+ if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+ else:
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+
+ state_dict_keys = list(converted_state_dict.keys())
+ for key in state_dict_keys:
+ new_key = key[:]
+ if new_key.startswith(PREFIX_KEY):
+ new_key = new_key.removeprefix(PREFIX_KEY)
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ state_dict_keys = list(converted_state_dict.keys())
+ for key in state_dict_keys:
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py
index 6c312b7a5a..3a6cb1ce6e 100644
--- a/src/diffusers/models/transformers/transformer_cosmos.py
+++ b/src/diffusers/models/transformers/transformer_cosmos.py
@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -377,7 +378,7 @@ class CosmosLearnablePositionalEmbed(nn.Module):
return (emb / norm).type_as(hidden_states)
-class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
+class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 0ae7f2c00d..5fb71b69f7 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -71,14 +71,22 @@ class WanAttnProcessor2_0:
if rotary_emb is not None:
- def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
- dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
- x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
- return x_out.type_as(hidden_states)
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
+ x1, x2 = x[..., 0], x[..., 1]
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
- query = apply_rotary_emb(query, rotary_emb)
- key = apply_rotary_emb(key, rotary_emb)
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
@@ -179,7 +187,11 @@ class WanTimeTextImageEmbedding(nn.Module):
class WanRotaryPosEmbed(nn.Module):
def __init__(
- self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
):
super().__init__()
@@ -189,36 +201,52 @@ class WanRotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
-
- freqs = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
for dim in [t_dim, h_dim, w_dim]:
- freq = get_1d_rotary_pos_embed(
- dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
)
- freqs.append(freq)
- self.freqs = torch.cat(freqs, dim=1)
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
- freqs = self.freqs.to(hidden_states.device)
- freqs = freqs.split_with_sizes(
- [
- self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
- self.attention_head_dim // 6,
- self.attention_head_dim // 6,
- ],
- dim=1,
- )
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
- freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
- freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
- return freqs
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+
+ return freqs_cos, freqs_sin
class WanTransformerBlock(nn.Module):
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 892c6f5a4c..1904c02999 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -141,6 +141,7 @@ else:
"FluxPriorReduxPipeline",
"ReduxImageEncoder",
"FluxKontextPipeline",
+ "FluxKontextInpaintPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -610,6 +611,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextInpaintPipeline,
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index 7d6a29ceca..598e3b5b6d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -29,7 +29,7 @@ from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
@@ -88,7 +88,7 @@ EXAMPLE_DOC_STRING = """
"""
-class BlipDiffusionControlNetPipeline(DiffusionPipeline):
+class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
@@ -116,6 +116,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py
index 117ce46f20..ea25c148e2 100644
--- a/src/diffusers/pipelines/flux/__init__.py
+++ b/src/diffusers/pipelines/flux/__init__.py
@@ -34,6 +34,7 @@ else:
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
_import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"]
+ _import_structure["pipeline_flux_kontext_inpaint"] = ["FluxKontextInpaintPipeline"]
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -54,6 +55,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
from .pipeline_flux_kontext import FluxKontextPipeline
+ from .pipeline_flux_kontext_inpaint import FluxKontextInpaintPipeline
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index b4f77cf019..ea49821adc 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -163,9 +163,9 @@ class FluxControlPipeline(
TextualInversionLoaderMixin,
):
r"""
- The Flux pipeline for controllable text-to-image generation.
+ The Flux pipeline for controllable text-to-image generation with image conditions.
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+ Reference: https://bfl.ai/flux-1-tools
Args:
transformer ([`FluxTransformer2DModel`]):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
index 07b9b895a4..94901ee0b6 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
@@ -195,9 +195,9 @@ class FluxKontextPipeline(
FluxIPAdapterMixin,
):
r"""
- The Flux Kontext pipeline for text-to-image generation.
+ The Flux Kontext pipeline for image-to-image and text-to-image generation.
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
Args:
transformer ([`FluxTransformer2DModel`]):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
new file mode 100644
index 0000000000..2b4abe8b24
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
@@ -0,0 +1,1459 @@
+# Copyright 2025 ZenAI. All rights reserved.
+# author: @vuongminh1907
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ # Inpainting with text only
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> prompt = "Change the yellow dinosaur to green one"
+ >>> img_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
+ ... )
+ >>> mask_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
+ ... )
+
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
+ >>> image.save("kontext_inpainting_normal.png")
+ ```
+
+ # Inpainting with image conditioning
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> prompt = "Replace this ball"
+ >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
+ >>> mask_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
+ ... )
+ >>> image_reference_url = (
+ ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
+ ... )
+
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image_reference = load_image(image_reference_url)
+
+ >>> mask = pipe.mask_processor.blur(mask, blur_factor=12)
+ >>> image = pipe(
+ ... prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
+ ... ).images[0]
+ >>> image.save("kontext_inpainting_ref.png")
+ ```
+"""
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+ (672, 1568),
+ (688, 1504),
+ (720, 1456),
+ (752, 1392),
+ (800, 1328),
+ (832, 1248),
+ (880, 1184),
+ (944, 1104),
+ (1024, 1024),
+ (1104, 944),
+ (1184, 880),
+ (1248, 832),
+ (1328, 800),
+ (1392, 752),
+ (1456, 720),
+ (1504, 688),
+ (1568, 672),
+]
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
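+    # Illustrative note: with the default arguments, mu is a linear interpolation between base_shift
+    # and max_shift as a function of the image token count, e.g. image_seq_len=256 -> mu=0.5 and
+    # image_seq_len=4096 (a 1024x1024 image after 2x2 patch packing) -> mu=1.15.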
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxKontextInpaintPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+    The Flux Kontext pipeline for image inpainting, optionally conditioned on a reference image.
+
+    Reference: https://bfl.ai/announcements/flux-1-kontext-dev
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor],
+ timestep: int,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ image_reference: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ # Prepare image latents
+ image_latents = image_ids = None
+ if image is not None:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ # Prepare image reference latents
+ image_reference_latents = image_reference_ids = None
+ if image_reference is not None:
+ image_reference = image_reference.to(device=device, dtype=dtype)
+ if image_reference.shape[1] != self.latent_channels:
+ image_reference_latents = self._encode_vae_image(image=image_reference, generator=generator)
+ else:
+ image_reference_latents = image_reference
+ if batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_reference_latents.shape[0]
+ image_reference_latents = torch.cat([image_reference_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image_reference` of batch size {image_reference_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_reference_latents = torch.cat([image_reference_latents], dim=0)
+
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
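+        # Note: when `latents` is not passed, the initial latents are the encoded input image noised to
+        # the first retained timestep; with strength == 1.0 this is essentially pure noise, so the input
+        # image mainly acts as conditioning rather than as a starting point.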
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device=device, dtype=dtype)
+ latents = noise
+
+ image_latent_height, image_latent_width = image_latents.shape[2:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ image_ids = self._prepare_latent_image_ids(
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_ids[..., 0] = 1
+
+ if image_reference_latents is not None:
+ image_reference_latent_height, image_reference_latent_width = image_reference_latents.shape[2:]
+ image_reference_latents = self._pack_latents(
+ image_reference_latents,
+ batch_size,
+ num_channels_latents,
+ image_reference_latent_height,
+ image_reference_latent_width,
+ )
+ image_reference_ids = self._prepare_latent_image_ids(
+ batch_size, image_reference_latent_height // 2, image_reference_latent_width // 2, device, dtype
+ )
+ # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_reference_ids[..., 0] = 1
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
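+        # Packed shape sketch (illustrative): for a 1024x1024 generation with vae_scale_factor 8, the
+        # (B, 16, 128, 128) latents become (B, 4096, 64) after 2x2 patch packing.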
+
+ return latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (
+ masked_image_latents - self.vae.config.shift_factor
+ ) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ image_reference: Optional[PipelineImageInput] = None,
+ mask_image: PipelineImageInput = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ padding_mask_crop: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ max_area: int = 1024**2,
+ _auto_resize: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be inpainted (the parts of the image
+                masked out with `mask_image` are repainted according to `prompt` and `image_reference`). For both
+                numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a
+                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. Image latents can also
+                be passed as `image`, in which case they are not encoded again.
+ image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch used as a reference for the content of the
+                masked area. For both numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`.
+                If it's a tensor or a list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If
+                it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+                Image latents can also be passed as `image_reference`, in which case they are not encoded again.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+                color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+                `(B, H, W)`, `(1, H, W)`, or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`,
+                or `(H, W)`.
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+                be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+                When > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 1.0):
+                Indicates the extent to transform the input `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied
+                to the image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular
+                region with the same aspect ratio as the image that contains all of the masked area, and then expand
+                that region based on `padding_mask_crop`. The image and mask_image will then be cropped based on the
+                expanded region before resizing to the original image size for inpainting. This is useful when the
+                masked area is small while the image is large and contains information irrelevant for inpainting,
+                such as the background.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+ max_area (`int`, defaults to `1024 ** 2`):
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+ area while maintaining the aspect ratio.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_height, original_width = height, width
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
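+        # Worked example (illustrative, assuming the default vae_scale_factor of 8, so multiple_of is 16):
+        # height=1024, width=1536 and max_area=1024**2 give aspect_ratio=1.5; the dimensions are first
+        # rescaled to about 836x1254 and then floored to multiples of 16, yielding height=832, width=1248.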
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ img = image[0] if isinstance(image, list) else image
+ image_height, image_width = self.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ image = self.image_processor.resize(image, image_height, image_width)
+
+            # Use the (resized) input image resolution as the generation resolution
+ width = image_width
+ height = image_height
+
+ # 2.1 Preprocess mask
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ image = self.image_processor.preprocess(
+ image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ else:
+ raise ValueError("image must be provided correctly for inpainting")
+
+ init_image = image.to(dtype=torch.float32)
+
+ # 2.1 Preprocess image_reference
+ if image_reference is not None and not (
+ isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
+ ):
+ if (
+ isinstance(image_reference, list)
+ and isinstance(image_reference[0], torch.Tensor)
+ and image_reference[0].ndim == 4
+ ):
+ image_reference = torch.cat(image_reference, dim=0)
+ img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference
+ image_reference_height, image_reference_width = self.image_processor.get_default_height_width(
+ img_reference
+ )
+ aspect_ratio = image_reference_width / image_reference_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_reference_width, image_reference_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_reference_width = image_reference_width // multiple_of * multiple_of
+ image_reference_height = image_reference_height // multiple_of * multiple_of
+ image_reference = self.image_processor.resize(
+ image_reference, image_reference_height, image_reference_width
+ )
+ image_reference = self.image_processor.preprocess(
+ image_reference,
+ image_reference_height,
+ image_reference_width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+ else:
+ image_reference = None
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
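+        # Illustrative example: with num_inference_steps=28 and strength=0.5, get_timesteps keeps only
+        # the last 14 timesteps, so 14 denoising steps are run; strength=1.0 keeps all 28.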
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = (
+ self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image_reference,
+ )
+ )
+
+ if image_reference_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_reference_ids], dim=0) # dim 0 is sequence dimension
+ elif image_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
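+
+        # The conditioning image tokens are appended along the sequence dimension; the transformer sees
+        # them as extra context, and only the first `latents.size(1)` tokens of the prediction are kept
+        # (see the slicing of `noise_pred` in the denoising loop).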
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ masked_image = init_image * (mask_condition < 0.5)
+
+ mask, _ = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ latent_model_input = latents
+ if image_reference_latents is not None:
+ latent_model_input = torch.cat([latents, image_reference_latents], dim=1)
+ elif image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
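+ # keep only the predictions for the noisy latents; drop the image-conditioning tokens concatenated above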
+ noise_pred = noise_pred[:, : latents.size(1)]
+
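+ # "true" classifier-free guidance: a second forward pass with the negative prompt, combined with the positive prediction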
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
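+ # inpainting blend: re-noise the original image latents to the next timestep and keep them where the mask is 0 (unmasked regions)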
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 62f1735695..e9c732d164 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -722,6 +722,21 @@ class FluxInpaintPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class FluxKontextInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class FluxKontextPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index a81128fa44..1c5a9b00e9 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -120,7 +120,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
- "Lora not correctly set in text encoder",
+ "Lora not correctly set in unet",
)
# We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
- "Lora not correctly set in text encoder",
+ "Lora not correctly set in unet",
)
for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
if "lora_" in name:
self.assertNotEqual(param.device, torch.device("cpu"))
+ @slow
+ @require_torch_accelerator
+ def test_integration_set_lora_device_different_target_layers(self):
+ # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+ # layers, see #11833
+ from peft import LoraConfig
+
+ path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+ pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+ # configs partly target the same, partly different layers
+ config0 = LoraConfig(target_modules=["to_k", "to_v"])
+ config1 = LoraConfig(target_modules=["to_k", "to_q"])
+ pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+ pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+ pipe = pipe.to(torch_device)
+
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.unet),
+ "Lora not correctly set in unet",
+ )
+
+ # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+ modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+ modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+ self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+ self.assertTrue(modules_adapter_0 - modules_adapter_1)
+ self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+ # setting both separately works
+ pipe.set_lora_device(["adapter-0"], "cpu")
+ pipe.set_lora_device(["adapter-1"], "cpu")
+
+ for name, module in pipe.unet.named_modules():
+ if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device == torch.device("cpu"))
+ elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device == torch.device("cpu"))
+
+ # setting both at once also works
+ pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+ for name, module in pipe.unet.named_modules():
+ if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device != torch.device("cpu"))
+ elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device != torch.device("cpu"))
+
@slow
@nightly
diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py
new file mode 100644
index 0000000000..a7eb740804
--- /dev/null
+++ b/tests/lora/test_lora_layers_wanvace.py
@@ -0,0 +1,222 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+import safetensors.torch
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers.utils.import_utils import is_peft_available
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ is_flaky,
+ require_peft_backend,
+ require_peft_version_greater,
+ skip_mps,
+ torch_device,
+)
+
+
+if is_peft_available():
+ from peft.utils import get_peft_model_state_dict
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = WanVACEPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 32,
+ "freq_dim": 16,
+ "ffn_dim": 16,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 16,
+ "vace_layers": [0],
+ "vace_in_channels": 72,
+ }
+ transformer_cls = WanVACETransformer3DModel
+ vae_kwargs = {
+ "base_dim": 3,
+ "z_dim": 4,
+ "dim_mult": [1, 1, 1, 1],
+ "latents_mean": torch.randn(4).numpy().tolist(),
+ "latents_std": torch.randn(4).numpy().tolist(),
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+ vae_cls = AutoencoderKLWan
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+
+ @property
+ def output_shape(self):
+ return (1, 9, 16, 16, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ sizes = (4, 4)
+ height, width = 16, 16
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+ video = [Image.new("RGB", (height, width))] * num_frames
+ mask = [Image.new("L", (height, width), 0)] * num_frames
+
+ pipeline_inputs = {
+ "video": video,
+ "mask": mask,
+ "prompt": "",
+ "num_frames": num_frames,
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": height,
+ "width": height,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
+
+ @pytest.mark.xfail(
+ condition=True,
+ reason="RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same",
+ strict=True,
+ )
+ def test_layerwise_casting_inference_denoiser(self):
+ super().test_layerwise_casting_inference_denoiser()
+
+ @require_peft_version_greater("0.13.2")
+ def test_lora_exclude_modules_wanvace(self):
+ scheduler_cls = self.scheduler_classes[0]
+ exclude_module_name = "vace_blocks.0.proj_out"
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # `exclude_modules` is only supported for the denoiser for now
+ denoiser_lora_config.target_modules = ["proj_out"]
+ denoiser_lora_config.exclude_modules = [exclude_module_name]
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ # The state dict shouldn't contain the modules to be excluded from LoRA.
+ state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
+ self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+ self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+ output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
+ pipe.unload_lora_weights()
+
+ # Check in the loaded state dict.
+ loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
+ self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
+
+ # Check in the state dict obtained after loading LoRA.
+ pipe.load_lora_weights(tmpdir)
+ state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
+ self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+ self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+
+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+ "LoRA should change outputs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+ "Lora outputs should match.",
+ )
+
+ @is_flaky
+ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+ super().test_simple_inference_with_text_denoiser_lora_and_scale()
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index dcc7ae16a4..def81ecd64 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1350,7 +1350,6 @@ class ModelTesterMixin:
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
- print(f" new_model.hf_device_map:{new_model.hf_device_map}")
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:
"""
+ different_shapes_for_compilation = None
+
def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2056,11 +2057,13 @@ class LoraHotSwappingForModelTesterMixin:
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
+ - optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
+ different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2110,19 +2113,30 @@ class LoraHotSwappingForModelTesterMixin:
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
- model = torch.compile(model, mode="reduce-overhead")
+ model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
- output0_after = model(**inputs_dict)["sample"]
- assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+ # additionally check if dynamic compilation works.
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output0_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
- output1_after = model(**inputs_dict)["sample"]
- assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output1_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
# check error when not passing valid adapter name
name = "does-not-exist"
@@ -2240,3 +2254,23 @@ class LoraHotSwappingForModelTesterMixin:
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+ @require_torch_version_greater("2.7.1")
+ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+ different_shapes_for_compilation = self.different_shapes_for_compilation
+ if different_shapes_for_compilation is None:
+ pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+ # Setting `use_duck_shape=False` tells the compiler not to reuse the same symbolic
+ # variable for input sizes that happen to be equal. For more details,
+ # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+ torch.fx.experimental._config.use_duck_shape = False
+
+ target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ with torch._dynamo.config.patch(error_on_recompile=True):
+ self.check_model_hotswap(
+ do_compile=True,
+ rank0=rank0,
+ rank1=rank1,
+ target_modules0=target_modules,
+ )
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 4552b2e1f5..68b5c02bc0 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -186,6 +186,10 @@ class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
new file mode 100644
index 0000000000..615209264d
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
@@ -0,0 +1,190 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ FluxKontextInpaintPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import floats_tensor, torch_device
+
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+)
+
+
+class FluxKontextInpaintPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+):
+ pipeline_class = FluxKontextInpaintPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["image", "prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ mask_image = torch.ones((1, 1, 32, 32)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 48,
+ "strength": 0.8,
+ "output_type": "np",
+ "_auto_resize": False,
+ }
+ return inputs
+
+ def test_flux_inpaint_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reason, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 56)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
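+ # the pipeline rounds the requested size down to a multiple of vae_scale_factor * 2, hence the expected shapes above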
+ # The output shape follows the input image shape, so create the dummy image and mask at the target size
+ image = floats_tensor((1, 3, height, width), rng=random.Random(0)).to(torch_device)
+ mask_image = torch.ones((1, 1, height, width)).to(torch_device)
+
+ inputs.update(
+ {
+ "height": height,
+ "width": width,
+ "max_area": height * width,
+ "image": image,
+ "mask_image": mask_image,
+ }
+ )
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_flux_true_cfg(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("generator")
+
+ no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ inputs["negative_prompt"] = "bad quality"
+ inputs["true_cfg_scale"] = 2.0
+ true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ assert not np.allclose(no_true_cfg_out, true_cfg_out)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 69dd79bb56..f87778b260 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1378,7 +1378,6 @@ class PipelineTesterMixin:
for component in pipe_fp16.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
-
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
@@ -1386,17 +1385,20 @@ class PipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
-
output = pipe(**inputs)[0]
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
-
output_fp16 = pipe_fp16(**fp16_inputs)[0]
+
+ if isinstance(output, torch.Tensor):
+ output = output.cpu()
+ output_fp16 = output_fp16.cpu()
+
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
- assert max_diff < 1e-2
+ assert max_diff < expected_max_diff
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index c5497d1c8d..06116cac3a 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -98,7 +98,14 @@ class Base4bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- torch.use_deterministic_algorithms(True)
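+ # remember the global determinism flag so tearDownClass can restore it and not leak state into other tests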
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 383cdd6849..2ea4cdfde8 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -99,7 +99,14 @@ class Base8bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- torch.use_deterministic_algorithms(True)
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 5d1fa4c22e..0d786de7e7 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -15,6 +15,8 @@ from diffusers import (
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
+ WanTransformer3DModel,
+ WanVACETransformer3DModel,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
@@ -577,3 +579,71 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
).to(torch_device, self.torch_dtype),
"timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
+
+
+class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "encoder_hidden_states_image": torch.randn(
+ (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0)
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanVACETransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states": torch.randn(
+ (1, 96, 2, 64, 64),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states_scale": torch.randn(
+ (8,),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
diff --git a/utils/print_env.py b/utils/print_env.py
index 2d2acb59d5..2fe0777daf 100644
--- a/utils/print_env.py
+++ b/utils/print_env.py
@@ -28,6 +28,16 @@ print("Python version:", sys.version)
print("OS platform:", platform.platform())
print("OS architecture:", platform.machine())
+try:
+ import psutil
+
+ vm = psutil.virtual_memory()
+ total_gb = vm.total / (1024**3)
+ available_gb = vm.available / (1024**3)
+ print(f"Total RAM: {total_gb:.2f} GB")
+ print(f"Available RAM: {available_gb:.2f} GB")
+except ImportError:
+ pass
try:
import torch