diff --git a/docs/source/en/api/pipelines/unclip.md b/docs/source/en/api/pipelines/unclip.md
index c9a3164226..8011a4b533 100644
--- a/docs/source/en/api/pipelines/unclip.md
+++ b/docs/source/en/api/pipelines/unclip.md
@@ -7,6 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# unCLIP
[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).
diff --git a/docs/source/en/api/pipelines/unidiffuser.md b/docs/source/en/api/pipelines/unidiffuser.md
index bce55b67ed..7d767f2db5 100644
--- a/docs/source/en/api/pipelines/unidiffuser.md
+++ b/docs/source/en/api/pipelines/unidiffuser.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# UniDiffuser
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index 18b8207e3b..81cd242151 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
```py
# pip install ftfy
import torch
- from diffusers import WanPipeline, AutoModel
+ from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
- vae = AutoModel.from_single_file(
+ vae = AutoencoderKLWan.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
- transformer = AutoModel.from_single_file(
+ transformer = WanTransformer3DModel.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
torch_dtype=torch.bfloat16
)
diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md
index 561df2017d..2be3631d84 100644
--- a/docs/source/en/api/pipelines/wuerstchen.md
+++ b/docs/source/en/api/pipelines/wuerstchen.md
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
# Würstchen
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index b18977720c..5a382c1c94 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
+If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
+
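+As a rough sketch, compilation with dynamic shapes could look like the following. It assumes the denoiser is exposed as `pipeline.transformer` (use `pipeline.unet` for UNet-based models) and that `fullgraph=True` is appropriate for your setup.
+
+```py
+# dynamic=True keeps a single compiled graph that can serve multiple resolutions,
+# so changing the requested image size between calls does not force a recompile
+pipeline.transformer = torch.compile(pipeline.transformer, fullgraph=True, dynamic=True)
+```
+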
There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
## Merge
diff --git a/docs/source/en/using-diffusers/batched_inference.md b/docs/source/en/using-diffusers/batched_inference.md
new file mode 100644
index 0000000000..b5e55c27ca
--- /dev/null
+++ b/docs/source/en/using-diffusers/batched_inference.md
@@ -0,0 +1,264 @@
+
+
+# Batch inference
+
+Batch inference processes multiple prompts at a time to increase throughput. It is more efficient because processing several prompts at once makes full use of the GPU, whereas processing prompts one at a time underutilizes it.
+
+The downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.
+
+<hfoptions id="batched-inference">
+<hfoption id="text-to-image">
+
+For text-to-image, pass a list of prompts to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+prompts = [
+ "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+images = pipeline(
+ prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+images = pipeline(
+ prompt=prompts,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 3, figsize=(18, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+</hfoption>
+<hfoption id="image-to-image">
+
+For image-to-image, pass a list of input images and prompts to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+input_images = [
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+ "cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ image=input_images,
+ guidance_scale=8.0,
+ strength=0.5
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+
+images = pipeline(
+ prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ image=input_image,
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+input_images = [
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ image=input_images,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+</hfoption>
+</hfoptions>
+
+## Deterministic generation
+
+Enable reproducible batch generation by passing a list of [Generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) objects to the pipeline and tying each `Generator` to a seed so you can reuse it.
+
+Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch.
+
+Don't multiply a single `Generator` by the batch size because that creates only one `Generator` object, which is reused sequentially for every image in the batch.
+
+```py
+# avoid: one Generator reused sequentially for every image
+generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
+# prefer: a unique Generator per image
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
+```
+
+Pass the `generator` to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
+prompts = [
+ "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ generator=generator
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+You can use this to iteratively select an image associated with a seed and then improve on it by crafting a more detailed prompt.
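+
+For example, if the image produced by the `Generator` seeded with `0` looks promising, reuse that seed with a more detailed prompt. The sketch below assumes the pipeline, prompts, and generator list from the example above; the extra descriptors are only illustrative.
+
+```py
+# reuse seed 0 so each refined image keeps the overall composition of the first image
+detailed_prompts = [prompts[0] + t for t in [", golden hour", ", dramatic clouds", ", ultra sharp", ", vivid colors"]]
+generator = [torch.Generator(device="cuda").manual_seed(0) for _ in range(4)]
+images = pipeline(prompt=detailed_prompts, generator=generator).images
+```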
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md
index 60b8fee754..ac9350f24c 100644
--- a/docs/source/en/using-diffusers/reusing_seeds.md
+++ b/docs/source/en/using-diffusers/reusing_seeds.md
@@ -136,53 +136,3 @@ result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="
print("L_inf dist =", abs(result1 - result2).max())
"L_inf dist = tensor(0., device='cuda:0')"
```
-
-## Deterministic batch generation
-
-A practical application of creating reproducible pipelines is *deterministic batch generation*. You generate a batch of images and select one image to improve with a more detailed prompt. The main idea is to pass a list of [Generator's](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed so you can reuse it.
-
-Let's use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint and generate a batch of images.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline = pipeline.to("cuda")
-```
-
-Define four different `Generator`s and assign each `Generator` a seed (`0` to `3`). Then generate a batch of images and pick one to iterate on.
-
-> [!WARNING]
-> Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. If you multiply the `Generator` by the batch size integer, it only creates *one* `Generator` object that is used sequentially for each image in the batch.
->
-> ```py
-> [torch.Generator().manual_seed(seed)] * 4
-> ```
-
-```python
-generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
-prompt = "Labrador in the style of Vermeer"
-images = pipeline(prompt, generator=generator, num_images_per_prompt=4).images[0]
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-

-
-
-Let's improve the first image (you can choose any image you want) which corresponds to the `Generator` with seed `0`. Add some additional text to your prompt and then make sure you reuse the same `Generator` with seed `0`. All the generated images should resemble the first image.
-
-```python
-prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
-generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
-images = pipeline(prompt, generator=generator).images
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-

-
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 3c8b75a088..53ee0f89e2 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -837,11 +837,6 @@ def main(args):
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
- if args.train_norm_layers:
- for name, param in flux_transformer.named_parameters():
- if any(k in name for k in NORM_LAYER_PREFIXES):
- param.requires_grad = True
-
if args.lora_layers is not None:
if args.lora_layers != "all-linear":
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
)
flux_transformer.add_adapter(transformer_lora_config)
+ if args.train_norm_layers:
+ for name, param in flux_transformer.named_parameters():
+ if any(k in name for k in NORM_LAYER_PREFIXES):
+ param.requires_grad = True
+
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py
index 0c0426a1ef..6f6563ad64 100644
--- a/scripts/convert_cosmos_to_diffusers.py
+++ b/scripts/convert_cosmos_to_diffusers.py
@@ -95,7 +95,6 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
- # "extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 3670243de8..4ade3374d8 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -244,13 +244,20 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
- # create LoraConfig
- lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
+ # create LoraConfig
+ lora_config = _create_lora_config(
+ state_dict,
+ network_alphas,
+ metadata,
+ rank,
+ model_state_dict=self.state_dict(),
+ adapter_name=adapter_name,
+ )
+
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
- freqs = self.freqs.to(hidden_states.device)
- freqs = freqs.split_with_sizes(
- [
- self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
- self.attention_head_dim // 6,
- self.attention_head_dim // 6,
- ],
- dim=1,
- )
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
- freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
- freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
- return freqs
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+
+ return freqs_cos, freqs_sin
class WanTransformerBlock(nn.Module):
diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py
index acff268c9b..63b4a109ff 100644
--- a/src/diffusers/schedulers/scheduling_scm.py
+++ b/src/diffusers/schedulers/scheduling_scm.py
@@ -168,7 +168,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
else:
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
- print(f"Set timesteps: {self.timesteps}")
self._step_index = None
self._begin_index = None
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 3907bdd5b3..651fa27294 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(
+ rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
else:
lora_alpha = set(network_alpha_dict.values()).pop()
- # layer names without the Diffusers specific
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# for now we know that the "bias" keys are only associated with `lora_B`.
@@ -195,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
"use_dora": use_dora,
"lora_bias": lora_bias,
}
+
+    # Example: trying to load a FusionX LoRA into a Wan VACE model
+ exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+ if exclude_modules:
+ if not is_peft_version(">=", "0.14.0"):
+ msg = """
+It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
+version doesn't support passing `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
+peft`. In most cases, this can be safely ignored. But if it seems unexpected, please file an issue -
+https://github.com/huggingface/diffusers/issues/new
+ """
+ logger.debug(msg)
+ else:
+ lora_config_kwargs.update({"exclude_modules": exclude_modules})
+
return lora_config_kwargs
@@ -294,11 +310,7 @@ def check_peft_version(min_version: str) -> None:
def _create_lora_config(
- state_dict,
- network_alphas,
- metadata,
- rank_pattern_dict,
- is_unet: bool = True,
+ state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
from peft import LoraConfig
@@ -306,7 +318,12 @@ def _create_lora_config(
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
- rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+ rank_pattern_dict,
+ network_alpha_dict=network_alphas,
+ peft_state_dict=state_dict,
+ is_unet=is_unet,
+ model_state_dict=model_state_dict,
+ adapter_name=adapter_name,
)
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
if warn_msg:
logger.warning(warn_msg)
+
+
+def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+ """
+    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing
+    `model_state_dict` and `peft_state_dict`: a module from `model_state_dict` is added to the exclusion set if it
+    doesn't exist in `peft_state_dict`.
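+
+    Example (illustrative only; the state dicts below are hypothetical placeholders):
+
+    ```py
+    model_sd = {"blocks.0.attn.to_q.weight": 0, "blocks.0.attn.to_k.weight": 0}
+    lora_sd = {"blocks.0.attn.to_q.lora_A.weight": 0, "blocks.0.attn.to_q.lora_B.weight": 0}
+    _derive_exclude_modules(model_sd, lora_sd)  # -> ["blocks.0.attn.to_k"]
+    ```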
+ """
+ if model_state_dict is None:
+ return
+ all_modules = set()
+ string_to_replace = f"{adapter_name}." if adapter_name else ""
+
+ for name in model_state_dict.keys():
+ if string_to_replace:
+ name = name.replace(string_to_replace, "")
+ if "." in name:
+ module_name = name.rsplit(".", 1)[0]
+ all_modules.add(module_name)
+
+ target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+ exclude_modules = list(all_modules - target_modules_set)
+
+ return exclude_modules
diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index 95ec44b2bf..fe26a56e77 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -24,7 +24,11 @@ from diffusers import (
WanPipeline,
WanTransformer3DModel,
)
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ require_peft_backend,
+ skip_mps,
+)
sys.path.append(".")
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index acd6f5f343..91ca188137 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import os
import re
@@ -291,9 +292,21 @@ class PeftLoraLoaderMixinTests:
return modules_to_save
- def check_if_adapters_added_correctly(
- self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
- ):
+ def _get_exclude_modules(self, pipe):
+ from diffusers.utils.peft_utils import _derive_exclude_modules
+
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+ modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+ denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+ pipe.unload_lora_weights()
+ denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+ exclude_modules = _derive_exclude_modules(
+ denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+ )
+ return exclude_modules
+
+ def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -345,7 +358,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -428,7 +441,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -484,7 +497,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -522,7 +535,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
@@ -554,7 +567,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -589,7 +602,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -640,7 +653,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
state_dict = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -691,7 +704,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -734,7 +747,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -775,7 +788,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -819,7 +832,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
@@ -857,7 +870,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -893,7 +906,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1010,7 +1023,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)
@@ -1032,7 +1045,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)
@@ -1759,7 +1772,7 @@ class PeftLoraLoaderMixinTests:
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1850,7 +1863,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1937,7 +1950,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2119,7 +2132,7 @@ class PeftLoraLoaderMixinTests:
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
if storage_dtype is not None:
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2237,7 +2250,7 @@ class PeftLoraLoaderMixinTests:
)
pipe = self.pipeline_class(**components)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
@@ -2290,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2309,6 +2322,77 @@ class PeftLoraLoaderMixinTests:
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
)
+ def test_lora_unload_add_adapter(self):
+ """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ # unload and then add.
+ pipe.unload_lora_weights()
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ @require_peft_version_greater("0.13.2")
+ def test_lora_exclude_modules(self):
+ """
+        Test that `exclude_modules` works as expected. The test proceeds as follows:
+        we first create a pipeline and add a LoRA config to it. We then derive a `set`
+        of modules to exclude by comparing its denoiser state dict with the denoiser LoRA
+        state dict.
+
+        We then create a new LoRA config that includes the `exclude_modules` and run the checks.
+ """
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # only supported for `denoiser` now
+ pipe_cp = copy.deepcopy(pipe)
+ pipe_cp, _ = self.add_adapters_to_pipeline(
+ pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
+ pipe_cp.to("cpu")
+ del pipe_cp
+
+ denoiser_lora_config.exclude_modules = denoiser_exclude_modules
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(tmpdir)
+
+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+ "LoRA should change outputs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+ "Lora outputs should match.",
+ )
+
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index dcc7ae16a4..def81ecd64 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1350,7 +1350,6 @@ class ModelTesterMixin:
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
- print(f" new_model.hf_device_map:{new_model.hf_device_map}")
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:
"""
+ different_shapes_for_compilation = None
+
def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2056,11 +2057,13 @@ class LoraHotSwappingForModelTesterMixin:
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
+ - optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
+ different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2110,19 +2113,30 @@ class LoraHotSwappingForModelTesterMixin:
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
- model = torch.compile(model, mode="reduce-overhead")
+ model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
- output0_after = model(**inputs_dict)["sample"]
- assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+ # additionally check if dynamic compilation works.
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output0_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
- output1_after = model(**inputs_dict)["sample"]
- assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output1_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
# check error when not passing valid adapter name
name = "does-not-exist"
@@ -2240,3 +2254,23 @@ class LoraHotSwappingForModelTesterMixin:
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+ @require_torch_version_greater("2.7.1")
+ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+ different_shapes_for_compilation = self.different_shapes_for_compilation
+ if different_shapes_for_compilation is None:
+ pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+        # Setting `use_duck_shape=False` instructs the compiler not to reuse the same symbolic
+        # variable for input sizes that happen to be equal, so a recompile is not triggered when
+        # those sizes diverge later. For more details,
+        # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+        torch.fx.experimental._config.use_duck_shape = False
+
+        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(
+                do_compile=True,
+                rank0=rank0,
+                rank1=rank1,
+                target_modules0=target_modules,
+            )
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 4552b2e1f5..68b5c02bc0 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -186,6 +186,10 @@ class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 69dd79bb56..f87778b260 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1378,7 +1378,6 @@ class PipelineTesterMixin:
for component in pipe_fp16.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
-
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
@@ -1386,17 +1385,20 @@ class PipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
-
output = pipe(**inputs)[0]
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
-
output_fp16 = pipe_fp16(**fp16_inputs)[0]
+
+ if isinstance(output, torch.Tensor):
+ output = output.cpu()
+ output_fp16 = output_fp16.cpu()
+
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
- assert max_diff < 1e-2
+ assert max_diff < expected_max_diff
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index c5497d1c8d..06116cac3a 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -98,7 +98,14 @@ class Base4bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- torch.use_deterministic_algorithms(True)
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 383cdd6849..2ea4cdfde8 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -99,7 +99,14 @@ class Base8bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- torch.use_deterministic_algorithms(True)
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index e59e817f6b..9f3cd4bf95 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -16,6 +16,8 @@ from diffusers import (
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
+ WanTransformer3DModel,
+ WanVACETransformer3DModel,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
@@ -583,6 +585,74 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
}
+class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "encoder_hidden_states_image": torch.randn(
+ (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0)
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanVACETransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states": torch.randn(
+ (1, 96, 2, 64, 64),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states_scale": torch.randn(
+ (8,),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests):
torch_dtype = torch.bfloat16