diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3170e8e1c9..d79165a3ea 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -457,6 +457,8 @@ title: Flux - local: api/pipelines/control_flux_inpaint title: FluxControlInpaint + - local: api/pipelines/framepack + title: Framepack - local: api/pipelines/hidream title: HiDream-I1 - local: api/pipelines/hunyuandit @@ -573,6 +575,8 @@ title: UniDiffuser - local: api/pipelines/value_guided_sampling title: Value-guided sampling + - local: api/pipelines/visualcloze + title: VisualCloze - local: api/pipelines/wan title: Wan - local: api/pipelines/wuerstchen diff --git a/docs/source/en/api/pipelines/framepack.md b/docs/source/en/api/pipelines/framepack.md new file mode 100644 index 0000000000..bad8005166 --- /dev/null +++ b/docs/source/en/api/pipelines/framepack.md @@ -0,0 +1,209 @@ + + +# Framepack + +
+ LoRA +
+ +[Packing Input Frame Context in Next-Frame Prediction Models for Video Generation](https://arxiv.org/abs/2504.12626) by Lvmin Zhang and Maneesh Agrawala. + +*We present a neural network structure, FramePack, to train next-frame (or next-frame-section) prediction models for video generation. The FramePack compresses input frames to make the transformer context length a fixed number regardless of the video length. As a result, we are able to process a large number of frames using video diffusion with computation bottleneck similar to image diffusion. This also makes the training video batch sizes significantly higher (batch sizes become comparable to image diffusion training). We also propose an anti-drifting sampling method that generates frames in inverted temporal order with early-established endpoints to avoid exposure bias (error accumulation over iterations). Finally, we show that existing video diffusion models can be finetuned with FramePack, and their visual quality may be improved because the next-frame prediction supports more balanced diffusion schedulers with less extreme flow shift timesteps.* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## Available models + +| Model name | Description | +|:---|:---| +- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | Trained with the "inverted anti-drifting" strategy as described in the paper. Inference requires setting `sampling_type="inverted_anti_drifting"` when running the pipeline. | +- [`lllyasviel/FramePack_F1_I2V_HY_20250503`](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503) | Trained with a novel anti-drifting strategy but inference is performed in "vanilla" strategy as described in the paper. Inference requires setting `sampling_type="vanilla"` when running the pipeline. | + +## Usage + +Refer to the pipeline documentation for basic usage examples. The following section contains examples of offloading, different sampling methods, quantization, and more. + +### First and last frame to video + +The following example shows how to use Framepack with start and end image controls, using the inverted anti-drifiting sampling model. + +```python +import torch +from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel +from diffusers.utils import export_to_video, load_image +from transformers import SiglipImageProcessor, SiglipVisionModel + +transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16 +) +feature_extractor = SiglipImageProcessor.from_pretrained( + "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" +) +image_encoder = SiglipVisionModel.from_pretrained( + "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 +) +pipe = HunyuanVideoFramepackPipeline.from_pretrained( + "hunyuanvideo-community/HunyuanVideo", + transformer=transformer, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + torch_dtype=torch.float16, +) + +# Enable memory optimizations +pipe.enable_model_cpu_offload() +pipe.vae.enable_tiling() + +prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." +first_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" +) +last_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" +) +output = pipe( + image=first_image, + last_image=last_image, + prompt=prompt, + height=512, + width=512, + num_frames=91, + num_inference_steps=30, + guidance_scale=9.0, + generator=torch.Generator().manual_seed(0), + sampling_type="inverted_anti_drifting", +).frames[0] +export_to_video(output, "output.mp4", fps=30) +``` + +### Vanilla sampling + +The following example shows how to use Framepack with the F1 model trained with vanilla sampling but new regulation approach for anti-drifting. + +```python +import torch +from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel +from diffusers.utils import export_to_video, load_image +from transformers import SiglipImageProcessor, SiglipVisionModel + +transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + "lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16 +) +feature_extractor = SiglipImageProcessor.from_pretrained( + "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" +) +image_encoder = SiglipVisionModel.from_pretrained( + "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 +) +pipe = HunyuanVideoFramepackPipeline.from_pretrained( + "hunyuanvideo-community/HunyuanVideo", + transformer=transformer, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + torch_dtype=torch.float16, +) + +# Enable memory optimizations +pipe.enable_model_cpu_offload() +pipe.vae.enable_tiling() + +image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png" +) +output = pipe( + image=image, + prompt="A penguin dancing in the snow", + height=832, + width=480, + num_frames=91, + num_inference_steps=30, + guidance_scale=9.0, + generator=torch.Generator().manual_seed(0), + sampling_type="vanilla", +).frames[0] +export_to_video(output, "output.mp4", fps=30) +``` + +### Group offloading + +Group offloading ([`~hooks.apply_group_offloading`]) provides aggressive memory optimizations for offloading internal parts of any model to the CPU, with possibly no additional overhead to generation time. If you have very low VRAM available, this approach may be suitable for you depending on the amount of CPU RAM available. + +```python +import torch +from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel +from diffusers.hooks import apply_group_offloading +from diffusers.utils import export_to_video, load_image +from transformers import SiglipImageProcessor, SiglipVisionModel + +transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + "lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16 +) +feature_extractor = SiglipImageProcessor.from_pretrained( + "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" +) +image_encoder = SiglipVisionModel.from_pretrained( + "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 +) +pipe = HunyuanVideoFramepackPipeline.from_pretrained( + "hunyuanvideo-community/HunyuanVideo", + transformer=transformer, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + torch_dtype=torch.float16, +) + +# Enable group offloading +onload_device = torch.device("cuda") +offload_device = torch.device("cpu") +list(map( + lambda x: apply_group_offloading(x, onload_device, offload_device, offload_type="leaf_level", use_stream=True, low_cpu_mem_usage=True), + [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer] +)) +pipe.image_encoder.to(onload_device) +pipe.vae.to(onload_device) +pipe.vae.enable_tiling() + +image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png" +) +output = pipe( + image=image, + prompt="A penguin dancing in the snow", + height=832, + width=480, + num_frames=91, + num_inference_steps=30, + guidance_scale=9.0, + generator=torch.Generator().manual_seed(0), + sampling_type="vanilla", +).frames[0] +print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB") +export_to_video(output, "output.mp4", fps=30) +``` + +## HunyuanVideoFramepackPipeline + +[[autodoc]] HunyuanVideoFramepackPipeline + - all + - __call__ + +## HunyuanVideoPipelineOutput + +[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput + diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md index a2c8e8b20d..5d068c8b6e 100644 --- a/docs/source/en/api/pipelines/hunyuan_video.md +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -52,7 +52,6 @@ The following models are available for the image-to-video pipeline: | [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. | | [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). | | [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) | -- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | lllyasviel's paper introducing a new technique for long-context video generation called [Framepack](https://arxiv.org/abs/2504.12626). | ## Quantization diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index 26d17733c1..3464d88145 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -31,12 +31,103 @@ Available models: | Model name | Recommended dtype | |:-------------:|:-----------------:| -| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` | -| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` | -| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` | +| [`LTX Video 2B 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` | +| [`LTX Video 2B 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` | +| [`LTX Video 2B 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` | +| [`LTX Video 13B 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) | `torch.bfloat16` | +| [`LTX Video Spatial Upscaler 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) | `torch.bfloat16` | Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository. +## Recommended settings for generation + +For the best results, it is recommended to follow the guidelines mentioned in the official LTX Video [repository](https://github.com/Lightricks/LTX-Video). + +- Some variants of LTX Video are guidance-distilled. For guidance-distilled models, `guidance_scale` must be set to `1.0`. For any other models, `guidance_scale` should be set higher (e.g., `5.0`) for good generation quality. +- For variants with a timestep-aware VAE (LTXV 0.9.1 and above), it is recommended to set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`. +- For variants that support interpolation between multiple conditioning images and videos (LTXV 0.9.5 and above), it is recommended to use similar looking images/videos for the best results. High divergence between the conditionings may lead to abrupt transitions in the generated video. + +## Using LTX Video 13B 0.9.7 + +LTX Video 0.9.7 comes with a spatial latent upscaler and a 13B parameter transformer. The inference involves generating a low resolution video first, which is very fast, followed by upscaling and refining the generated video. + + + +```python +import torch +from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline +from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition +from diffusers.utils import export_to_video, load_video + +pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16) +pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16) +pipe.to("cuda") +pipe_upsample.to("cuda") +pipe.vae.enable_tiling() + +def round_to_nearest_resolution_acceptable_by_vae(height, width): + height = height - (height % pipe.vae_temporal_compression_ratio) + width = width - (width % pipe.vae_temporal_compression_ratio) + return height, width + +video = load_video( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" +)[:21] # Use only the first 21 frames as conditioning +condition1 = LTXVideoCondition(video=video, frame_index=0) + +prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" +expected_height, expected_width = 768, 1152 +downscale_factor = 2 / 3 +num_frames = 161 + +# Part 1. Generate video at smaller resolution +# Text-only conditioning is also supported without the need to pass `conditions` +downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor) +downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width) +latents = pipe( + conditions=[condition1], + prompt=prompt, + negative_prompt=negative_prompt, + width=downscaled_width, + height=downscaled_height, + num_frames=num_frames, + num_inference_steps=30, + generator=torch.Generator().manual_seed(0), + output_type="latent", +).frames + +# Part 2. Upscale generated video using latent upsampler with fewer inference steps +# The available latent upsampler upscales the height/width by 2x +upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2 +upscaled_latents = pipe_upsample( + latents=latents, + output_type="latent" +).frames + +# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended) +video = pipe( + conditions=[condition1], + prompt=prompt, + negative_prompt=negative_prompt, + width=upscaled_width, + height=upscaled_height, + num_frames=num_frames, + denoise_strength=0.4, # Effectively, 4 inference steps out of 10 + num_inference_steps=10, + latents=upscaled_latents, + decode_timestep=0.05, + image_cond_noise_scale=0.025, + generator=torch.Generator().manual_seed(0), + output_type="pil", +).frames[0] + +# Part 4. Downscale the video to the expected resolution +video = [frame.resize((expected_width, expected_height)) for frame in video] + +export_to_video(video, "output.mp4", fps=24) +``` + ## Loading Single Files Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format. @@ -204,6 +295,12 @@ export_to_video(video, "ship.mp4", fps=24) - all - __call__ +## LTXLatentUpsamplePipeline + +[[autodoc]] LTXLatentUpsamplePipeline + - all + - __call__ + ## LTXPipelineOutput [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index 6a8e82a692..95b50ce608 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -89,6 +89,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation | | [Value-guided planning](value_guided_sampling) | value guided sampling | | [Wuerstchen](wuerstchen) | text2image | +| [VisualCloze](visualcloze) | text2image, image2image, subject driven generation, inpainting, style transfer, image restoration, image editing, [depth,normal,edge,pose]2image, [depth,normal,edge,pose]-estimation, virtual try-on, image relighting | ## DiffusionPipeline diff --git a/docs/source/en/api/pipelines/visualcloze.md b/docs/source/en/api/pipelines/visualcloze.md new file mode 100644 index 0000000000..f43cfbeb5a --- /dev/null +++ b/docs/source/en/api/pipelines/visualcloze.md @@ -0,0 +1,300 @@ + + +# VisualCloze + +[VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning](https://arxiv.org/abs/2504.07960) is an innovative in-context learning based universal image generation framework that offers key capabilities: +1. Support for various in-domain tasks +2. Generalization to unseen tasks through in-context learning +3. Unify multiple tasks into one step and generate both target image and intermediate results +4. Support reverse-engineering conditions from target images + +## Overview + +The abstract from the paper is: + +*Recent progress in diffusion models significantly advances various image generation tasks. However, the current mainstream approach remains focused on building task-specific models, which have limited efficiency when supporting a wide range of different needs. While universal models attempt to address this limitation, they face critical challenges, including generalizable task instruction, appropriate task distributions, and unified architectural design. To tackle these challenges, we propose VisualCloze, a universal image generation framework, which supports a wide range of in-domain tasks, generalization to unseen ones, unseen unification of multiple tasks, and reverse generation. Unlike existing methods that rely on language-based task instruction, leading to task ambiguity and weak generalization, we integrate visual in-context learning, allowing models to identify tasks from visual demonstrations. Meanwhile, the inherent sparsity of visual task distributions hampers the learning of transferable knowledge across tasks. To this end, we introduce Graph200K, a graph-structured dataset that establishes various interrelated tasks, enhancing task density and transferable knowledge. Furthermore, we uncover that our unified image generation formulation shared a consistent objective with image infilling, enabling us to leverage the strong generative priors of pre-trained infilling models without modifying the architectures. The codes, dataset, and models are available at https://visualcloze.github.io.* + +## Inference + +### Model loading + +VisualCloze is a two-stage cascade pipeline, containing `VisualClozeGenerationPipeline` and `VisualClozeUpsamplingPipeline`. +- In `VisualClozeGenerationPipeline`, each image is downsampled before concatenating images into a grid layout, avoiding excessively high resolutions. VisualCloze releases two models suitable for diffusers, i.e., [VisualClozePipeline-384](https://huggingface.co/VisualCloze/VisualClozePipeline-384) and [VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-384), which downsample images to resolutions of 384 and 512, respectively. +- `VisualClozeUpsamplingPipeline` uses [SDEdit](https://arxiv.org/abs/2108.01073) to enable high-resolution image synthesis. + +The `VisualClozePipeline` integrates both stages to support convenient end-to-end sampling, while also allowing users to utilize each pipeline independently as needed. + +### Input Specifications + +#### Task and Content Prompts +- Task prompt: Required to describe the generation task intention +- Content prompt: Optional description or caption of the target image +- When content prompt is not needed, pass `None` +- For batch inference, pass `List[str|None]` + +#### Image Input Format +- Format: `List[List[Image|None]]` +- Structure: + - All rows except the last represent in-context examples + - Last row represents the current query (target image set to `None`) +- For batch inference, pass `List[List[List[Image|None]]]` + +#### Resolution Control +- Default behavior: + - Initial generation in the first stage: area of ${pipe.resolution}^2$ + - Upsampling in the second stage: 3x factor +- Custom resolution: Adjust using `upsampling_height` and `upsampling_width` parameters + +### Examples + +For comprehensive examples covering a wide range of tasks, please refer to the [Online Demo](https://huggingface.co/spaces/VisualCloze/VisualCloze) and [GitHub Repository](https://github.com/lzyhha/VisualCloze). Below are simple examples for three cases: mask-to-image conversion, edge detection, and subject-driven generation. + +#### Example for mask2image + +```python +import torch +from diffusers import VisualClozePipeline +from diffusers.utils import load_image + +pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Load in-context images (make sure the paths are correct and accessible) +image_paths = [ + # in-context examples + [ + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg'), + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg'), + ], + # query with the target image + [ + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg'), + None, # No image needed for the target image + ], +] + +# Task and content prompt +task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding." +content_prompt = """Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. +The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. +Its plumage is a mix of dark brown and golden hues, with intricate feather details. +The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. +The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, +soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, +tranquil, majestic, wildlife photography.""" + +# Run the pipeline +image_result = pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image_paths, + upsampling_width=1344, + upsampling_height=768, + upsampling_strength=0.4, + guidance_scale=30, + num_inference_steps=30, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0) +).images[0][0] + +# Save the resulting image +image_result.save("visualcloze.png") +``` + +#### Example for edge-detection + +```python +import torch +from diffusers import VisualClozePipeline +from diffusers.utils import load_image + +pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Load in-context images (make sure the paths are correct and accessible) +image_paths = [ + # in-context examples + [ + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_image.jpg'), + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_edge.jpg'), + ], + [ + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_image.jpg'), + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_edge.jpg'), + ], + # query with the target image + [ + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_query_image.jpg'), + None, # No image needed for the target image + ], +] + +# Task and content prompt +task_prompt = "Each row illustrates a pathway from [IMAGE1] a sharp and beautifully composed photograph to [IMAGE2] edge map with natural well-connected outlines using a clear logical task." +content_prompt = "" + +# Run the pipeline +image_result = pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image_paths, + upsampling_width=864, + upsampling_height=1152, + upsampling_strength=0.4, + guidance_scale=30, + num_inference_steps=30, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0) +).images[0][0] + +# Save the resulting image +image_result.save("visualcloze.png") +``` + +#### Example for subject-driven generation + +```python +import torch +from diffusers import VisualClozePipeline +from diffusers.utils import load_image + +pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Load in-context images (make sure the paths are correct and accessible) +image_paths = [ + # in-context examples + [ + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_reference.jpg'), + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_depth.jpg'), + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_image.jpg'), + ], + [ + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_reference.jpg'), + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_depth.jpg'), + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_image.jpg'), + ], + # query with the target image + [ + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_reference.jpg'), + load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_depth.jpg'), + None, # No image needed for the target image + ], +] + +# Task and content prompt +task_prompt = """Each row describes a process that begins with [IMAGE1] an image containing the key object, +[IMAGE2] depth map revealing gray-toned spatial layers and results in +[IMAGE3] an image with artistic qualitya high-quality image with exceptional detail.""" +content_prompt = """A vintage porcelain collector's item. Beneath a blossoming cherry tree in early spring, +this treasure is photographed up close, with soft pink petals drifting through the air and vibrant blossoms framing the scene.""" + +# Run the pipeline +image_result = pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image_paths, + upsampling_width=1024, + upsampling_height=1024, + upsampling_strength=0.2, + guidance_scale=30, + num_inference_steps=30, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0) +).images[0][0] + +# Save the resulting image +image_result.save("visualcloze.png") +``` + +#### Utilize each pipeline independently + +```python +import torch +from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline +from diffusers.utils import load_image +from PIL import Image + +pipe = VisualClozeGenerationPipeline.from_pretrained( + "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +image_paths = [ + # in-context examples + [ + load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg" + ), + load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg" + ), + ], + # query with the target image + [ + load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg" + ), + None, # No image needed for the target image + ], +] +task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding." +content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography." + +# Stage 1: Generate initial image +image = pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image_paths, + guidance_scale=30, + num_inference_steps=30, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0), +).images[0][0] + +# Stage 2 (optional): Upsample the generated image +pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe) +pipe_upsample.to("cuda") + +mask_image = Image.new("RGB", image.size, (255, 255, 255)) + +image = pipe_upsample( + image=image, + mask_image=mask_image, + prompt=content_prompt, + width=1344, + height=768, + strength=0.4, + guidance_scale=30, + num_inference_steps=30, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0), +).images[0] + +image.save("visualcloze.png") +``` + +## VisualClozePipeline + +[[autodoc]] VisualClozePipeline + - all + - __call__ + +## VisualClozeGenerationPipeline + +[[autodoc]] VisualClozeGenerationPipeline + - all + - __call__ diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index aa5c069d29..6619735f18 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1114,17 +1114,22 @@ def main(args): ) # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, num_cycles=args.lr_num_cycles, power=args.lr_power, ) @@ -1156,8 +1161,14 @@ def main(args): # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: + if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py index a5d89f77d6..9f96ef944a 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py @@ -812,7 +812,7 @@ def main(args): if args.scale_lr: args.learning_rate = ( - args.learning_rat * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 2e966d5d11..256312cc72 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -7,7 +7,15 @@ from accelerate import init_empty_weights from safetensors.torch import load_file from transformers import T5EncoderModel, T5Tokenizer -from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXConditionPipeline, + LTXLatentUpsamplePipeline, + LTXPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel def remove_keys_(key: str, state_dict: Dict[str, Any]): @@ -123,17 +131,10 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: state_dict[new_key] = state_dict.pop(old_key) -def convert_transformer( - ckpt_path: str, - dtype: torch.dtype, - version: str = "0.9.0", -): +def convert_transformer(ckpt_path: str, config, dtype: torch.dtype): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(load_file(ckpt_path)) - config = {} - if version == "0.9.5": - config["_use_causal_rope_fix"] = True with init_empty_weights(): transformer = LTXVideoTransformer3DModel(**config) @@ -180,8 +181,59 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype): return vae +def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype): + original_state_dict = get_state_dict(load_file(ckpt_path)) + + with init_empty_weights(): + latent_upsampler = LTXLatentUpsamplerModel(**config) + + latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True) + latent_upsampler.to(dtype) + return latent_upsampler + + +def get_transformer_config(version: str) -> Dict[str, Any]: + if version == "0.9.7": + config = { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 4096, + "attention_bias": True, + "attention_out_bias": True, + } + else: + config = { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 64, + "cross_attention_dim": 2048, + "num_layers": 28, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 4096, + "attention_bias": True, + "attention_out_bias": True, + } + return config + + def get_vae_config(version: str) -> Dict[str, Any]: - if version == "0.9.0": + if version in ["0.9.0"]: config = { "in_channels": 3, "out_channels": 3, @@ -210,7 +262,7 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_causal": False, "timestep_conditioning": False, } - elif version == "0.9.1": + elif version in ["0.9.1"]: config = { "in_channels": 3, "out_channels": 3, @@ -240,7 +292,39 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_causal": False, } VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) - elif version == "0.9.5": + elif version in ["0.9.5"]: + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (128, 256, 512, 1024, 2048), + "down_block_types": ( + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": True, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "scaling_factor": 1.0, + "encoder_causal": True, + "decoder_causal": False, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + } + VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) + elif version in ["0.9.7"]: config = { "in_channels": 3, "out_channels": 3, @@ -275,12 +359,33 @@ def get_vae_config(version: str) -> Dict[str, Any]: return config +def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]: + if version == "0.9.7": + config = { + "in_channels": 128, + "mid_channels": 512, + "num_blocks_per_stage": 4, + "dims": 3, + "spatial_upsample": True, + "temporal_upsample": False, + } + else: + raise ValueError(f"Unsupported version: {version}") + return config + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") + parser.add_argument( + "--spatial_latent_upsampler_path", + type=str, + default=None, + help="Path to original spatial latent upsampler checkpoint", + ) parser.add_argument( "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" ) @@ -294,7 +399,11 @@ def get_args(): parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") parser.add_argument( - "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model" + "--version", + type=str, + default="0.9.0", + choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"], + help="Version of the LTX model", ) return parser.parse_args() @@ -320,11 +429,9 @@ if __name__ == "__main__": variant = VARIANT_MAPPING[args.dtype] output_path = Path(args.output_path) - if args.save_pipeline: - assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None - if args.transformer_ckpt_path is not None: - transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) + config = get_transformer_config(args.version) + transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype) if not args.save_pipeline: transformer.save_pretrained( output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant @@ -336,6 +443,16 @@ if __name__ == "__main__": if not args.save_pipeline: vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant) + if args.spatial_latent_upsampler_path is not None: + config = get_spatial_latent_upsampler_config(args.version) + latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler( + args.spatial_latent_upsampler_path, config, dtype + ) + if not args.save_pipeline: + latent_upsampler.save_pretrained( + output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant + ) + if args.save_pipeline: text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) @@ -348,7 +465,7 @@ if __name__ == "__main__": for param in text_encoder.parameters(): param.data = param.data.contiguous() - if args.version == "0.9.5": + if args.version in ["0.9.5", "0.9.7"]: scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False) else: scheduler = FlowMatchEulerDiscreteScheduler( @@ -360,12 +477,40 @@ if __name__ == "__main__": shift_terminal=0.1, ) - pipe = LTXPipeline( - scheduler=scheduler, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - transformer=transformer, - ) - - pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB") + if args.version in ["0.9.0", "0.9.1", "0.9.5"]: + pipe = LTXPipeline( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + pipe.save_pretrained( + output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB" + ) + elif args.version in ["0.9.7"]: + pipe = LTXConditionPipeline( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + pipe_upsample = LTXLatentUpsamplePipeline( + vae=vae, + latent_upsampler=latent_upsampler, + ) + pipe.save_pretrained( + (output_path / "ltx_pipeline").as_posix(), + safe_serialization=True, + variant=variant, + max_shard_size="5GB", + ) + pipe_upsample.save_pretrained( + (output_path / "ltx_upsample_pipeline").as_posix(), + safe_serialization=True, + variant=variant, + max_shard_size="5GB", + ) + else: + raise ValueError(f"Unsupported version: {args.version}") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9e172dc3e7..9ab973351c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -420,6 +420,7 @@ else: "LEditsPPPipelineStableDiffusionXL", "LTXConditionPipeline", "LTXImageToVideoPipeline", + "LTXLatentUpsamplePipeline", "LTXPipeline", "Lumina2Pipeline", "Lumina2Text2ImgPipeline", @@ -520,6 +521,8 @@ else: "VersatileDiffusionPipeline", "VersatileDiffusionTextToImagePipeline", "VideoToVideoSDPipeline", + "VisualClozeGenerationPipeline", + "VisualClozePipeline", "VQDiffusionPipeline", "WanImageToVideoPipeline", "WanPipeline", @@ -1001,6 +1004,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LEditsPPPipelineStableDiffusionXL, LTXConditionPipeline, LTXImageToVideoPipeline, + LTXLatentUpsamplePipeline, LTXPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, @@ -1100,6 +1104,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, VideoToVideoSDPipeline, + VisualClozeGenerationPipeline, + VisualClozePipeline, VQDiffusionPipeline, WanImageToVideoPipeline, WanPipeline, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index b7da4fb746..7a970c5c51 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -57,6 +57,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = { "WanTransformer3DModel": lambda model_cls, weights: weights, "CogView4Transformer2DModel": lambda model_cls, weights: weights, "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, + "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index a2f27b765a..ade2e457d8 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -187,9 +187,8 @@ class FromOriginalModelMixin: original_config (`str`, *optional*): Dict or path to a yaml file containing the configuration for the model in its original format. If a dict is provided, it will be used to initialize the model configuration. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 677a991f05..e475fe6bee 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -161,9 +161,8 @@ class MultiAdapter(ModelMixin): pretrained_model_path (`os.PathLike`): A path to a *directory* containing model weights saved using [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 1b742463aa..7cc9c50c0b 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -52,9 +52,8 @@ class AutoModel(ConfigMixin): cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. diff --git a/src/diffusers/models/controlnets/multicontrolnet.py b/src/diffusers/models/controlnets/multicontrolnet.py index cb022212de..87a9522949 100644 --- a/src/diffusers/models/controlnets/multicontrolnet.py +++ b/src/diffusers/models/controlnets/multicontrolnet.py @@ -130,9 +130,8 @@ class MultiControlNetModel(ModelMixin): A path to a *directory* containing model weights saved using [`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g., `./my_model_directory/controlnet`. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): diff --git a/src/diffusers/models/controlnets/multicontrolnet_union.py b/src/diffusers/models/controlnets/multicontrolnet_union.py index 40cfa6a884..d5506dc186 100644 --- a/src/diffusers/models/controlnets/multicontrolnet_union.py +++ b/src/diffusers/models/controlnets/multicontrolnet_union.py @@ -143,9 +143,8 @@ class MultiControlNetUnionModel(ModelMixin): A path to a *directory* containing model weights saved using [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g., `./my_model_directory/controlnet`. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 2a22bc09ad..55ce0cf79f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -787,9 +787,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 90ff30279f..4debb868d9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -268,7 +268,12 @@ else: ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"] + _import_structure["ltx"] = [ + "LTXPipeline", + "LTXImageToVideoPipeline", + "LTXConditionPipeline", + "LTXLatentUpsamplePipeline", + ] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( @@ -281,6 +286,7 @@ else: _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["omnigen"] = ["OmniGenPipeline"] + _import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] @@ -636,7 +642,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline + from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline from .marigold import ( @@ -727,6 +733,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: UniDiffuserPipeline, UniDiffuserTextDecoder, ) + from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .wuerstchen import ( WuerstchenCombinedPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 6a5f6098b6..ed8ad79ca7 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -322,9 +322,8 @@ class AutoPipelineForText2Image(ConfigMixin): - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights saved using [`~DiffusionPipeline.save_pretrained`]. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the - dtype is automatically derived from the model's weights. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -619,8 +618,7 @@ class AutoPipelineForImage2Image(ConfigMixin): saved using [`~DiffusionPipeline.save_pretrained`]. torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the - dtype is automatically derived from the model's weights. + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -930,8 +928,7 @@ class AutoPipelineForInpainting(ConfigMixin): saved using [`~DiffusionPipeline.save_pretrained`]. torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the - dtype is automatically derived from the model's weights. + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 64cd6ac45f..572ad5096b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -607,6 +607,39 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile return latents + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + def prepare_latents( self, image, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 43db740c6d..fdcc4e53a9 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -14,6 +14,7 @@ import inspect import math +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -91,6 +92,7 @@ EXAMPLE_DOC_STRING = """ ... num_inference_steps=30, ... guidance_scale=9.0, ... generator=torch.Generator().manual_seed(0), + ... sampling_type="inverted_anti_drifting", ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=30) ``` @@ -138,6 +140,7 @@ EXAMPLE_DOC_STRING = """ ... num_inference_steps=30, ... guidance_scale=9.0, ... generator=torch.Generator().manual_seed(0), + ... sampling_type="inverted_anti_drifting", ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=30) ``` @@ -232,6 +235,11 @@ def retrieve_timesteps( return timesteps, num_inference_steps +class FramepackSamplingType(str, Enum): + VANILLA = "vanilla" + INVERTED_ANTI_DRIFTING = "inverted_anti_drifting" + + class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): r""" Pipeline for text-to-video generation using HunyuanVideo. @@ -455,6 +463,11 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix prompt_embeds=None, callback_on_step_end_tensor_inputs=None, prompt_template=None, + image=None, + image_latents=None, + last_image=None, + last_image_latents=None, + sampling_type=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -493,6 +506,21 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" ) + sampling_types = [x.value for x in FramepackSamplingType.__members__.values()] + if sampling_type not in sampling_types: + raise ValueError(f"`sampling_type` has to be one of '{sampling_types}' but is '{sampling_type}'") + + if image is not None and image_latents is not None: + raise ValueError("Only one of `image` or `image_latents` can be passed.") + if last_image is not None and last_image_latents is not None: + raise ValueError("Only one of `last_image` or `last_image_latents` can be passed.") + if sampling_type != FramepackSamplingType.INVERTED_ANTI_DRIFTING and ( + last_image is not None or last_image_latents is not None + ): + raise ValueError( + 'Only `"inverted_anti_drifting"` inference type supports `last_image` or `last_image_latents`.' + ) + def prepare_latents( self, batch_size: int = 1, @@ -623,6 +651,7 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix callback_on_step_end_tensor_inputs: List[str] = ["latents"], prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, max_sequence_length: int = 256, + sampling_type: FramepackSamplingType = FramepackSamplingType.INVERTED_ANTI_DRIFTING, ): r""" The call function to the pipeline for generation. @@ -735,6 +764,11 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix prompt_embeds, callback_on_step_end_tensor_inputs, prompt_template, + image, + image_latents, + last_image, + last_image_latents, + sampling_type, ) has_neg_prompt = negative_prompt is not None or ( @@ -806,18 +840,6 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix num_channels_latents = self.transformer.config.in_channels window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames) - # Specific to the released checkpoint: https://huggingface.co/lllyasviel/FramePackI2V_HY - # TODO: find a more generic way in future if there are more checkpoints - history_sizes = [1, 2, 16] - history_latents = torch.zeros( - batch_size, - num_channels_latents, - sum(history_sizes), - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - device=device, - dtype=torch.float32, - ) history_video = None total_generated_latent_frames = 0 @@ -829,38 +851,92 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix last_image, dtype=torch.float32, device=device, generator=generator ) - latent_paddings = list(reversed(range(num_latent_sections))) - if num_latent_sections > 4: - latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0] + # Specific to the released checkpoints: + # - https://huggingface.co/lllyasviel/FramePackI2V_HY + # - https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503 + # TODO: find a more generic way in future if there are more checkpoints + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + history_sizes = [1, 2, 16] + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + + elif sampling_type == FramepackSamplingType.VANILLA: + history_sizes = [16, 2, 1] + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + history_latents = torch.cat([history_latents, image_latents], dim=2) + total_generated_latent_frames += 1 + + else: + assert False # 6. Prepare guidance condition guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0 # 7. Denoising loop for k in range(num_latent_sections): - is_first_section = k == 0 - is_last_section = k == num_latent_sections - 1 - latent_padding_size = latent_paddings[k] * latent_window_size + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + latent_paddings = list(reversed(range(num_latent_sections))) + if num_latent_sections > 4: + latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0] - indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])) - ( - indices_prefix, - indices_padding, - indices_latents, - indices_postfix, - indices_latents_history_2x, - indices_latents_history_4x, - ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0) - # Inverted anti-drifting sampling: Figure 2(c) in the paper - indices_clean_latents = torch.cat([indices_prefix, indices_postfix], dim=0) + is_first_section = k == 0 + is_last_section = k == num_latent_sections - 1 + latent_padding_size = latent_paddings[k] * latent_window_size - latents_prefix = image_latents - latents_postfix, latents_history_2x, latents_history_4x = history_latents[ - :, :, : sum(history_sizes) - ].split(history_sizes, dim=2) - if last_image is not None and is_first_section: - latents_postfix = last_image_latents - latents_clean = torch.cat([latents_prefix, latents_postfix], dim=2) + indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])) + ( + indices_prefix, + indices_padding, + indices_latents, + indices_latents_history_1x, + indices_latents_history_2x, + indices_latents_history_4x, + ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0) + # Inverted anti-drifting sampling: Figure 2(c) in the paper + indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix = image_latents + latents_history_1x, latents_history_2x, latents_history_4x = history_latents[ + :, :, : sum(history_sizes) + ].split(history_sizes, dim=2) + if last_image is not None and is_first_section: + latents_history_1x = last_image_latents + latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2) + + elif sampling_type == FramepackSamplingType.VANILLA: + indices = torch.arange(0, sum([1, *history_sizes, latent_window_size])) + ( + indices_prefix, + indices_latents_history_4x, + indices_latents_history_2x, + indices_latents_history_1x, + indices_latents, + ) = indices.split([1, *history_sizes, latent_window_size], dim=0) + indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix = image_latents + latents_history_4x, latents_history_2x, latents_history_1x = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2) + + else: + assert False latents = self.prepare_latents( batch_size, @@ -960,13 +1036,26 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix if XLA_AVAILABLE: xm.mark_step() - if is_last_section: - latents = torch.cat([image_latents, latents], dim=2) + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + if is_last_section: + latents = torch.cat([image_latents, latents], dim=2) + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([latents, history_latents], dim=2) + real_history_latents = history_latents[:, :, :total_generated_latent_frames] + section_latent_frames = ( + (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) + ) + index_slice = (slice(None), slice(None), slice(0, section_latent_frames)) - total_generated_latent_frames += latents.shape[2] - history_latents = torch.cat([latents, history_latents], dim=2) + elif sampling_type == FramepackSamplingType.VANILLA: + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + section_latent_frames = latent_window_size * 2 + index_slice = (slice(None), slice(None), slice(-section_latent_frames, None)) - real_history_latents = history_latents[:, :, :total_generated_latent_frames] + else: + assert False if history_video is None: if not output_type == "latent": @@ -976,16 +1065,18 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix history_video = [real_history_latents] else: if not output_type == "latent": - section_latent_frames = ( - (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) - ) overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 current_latents = ( - real_history_latents[:, :, :section_latent_frames].to(vae_dtype) - / self.vae.config.scaling_factor + real_history_latents[index_slice].to(vae_dtype) / self.vae.config.scaling_factor ) current_video = self.vae.decode(current_latents, return_dict=False)[0] - history_video = self._soft_append(current_video, history_video, overlapped_frames) + + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + history_video = self._soft_append(current_video, history_video, overlapped_frames) + elif sampling_type == FramepackSamplingType.VANILLA: + history_video = self._soft_append(history_video, current_video, overlapped_frames) + else: + assert False else: history_video.append(real_history_latents) diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 199e730d9b..6001867916 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -22,9 +22,11 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"] _import_structure["pipeline_ltx"] = ["LTXPipeline"] _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] + _import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,9 +36,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: + from .modeling_latent_upsampler import LTXLatentUpsamplerModel from .pipeline_ltx import LTXPipeline from .pipeline_ltx_condition import LTXConditionPipeline from .pipeline_ltx_image2video import LTXImageToVideoPipeline + from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline else: import sys diff --git a/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py b/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py new file mode 100644 index 0000000000..6dce792a2b --- /dev/null +++ b/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py @@ -0,0 +1,188 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index dcfdfaf232..50ef7ae7d1 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -430,6 +430,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL video, frame_index, strength, + denoise_strength, height, width, callback_on_step_end_tensor_inputs=None, @@ -497,6 +498,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") + if denoise_strength < 0 or denoise_strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}") + @staticmethod def _prepare_video_ids( batch_size: int, @@ -649,6 +653,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL width: int = 704, num_frames: int = 161, num_prefix_latent_frames: int = 2, + sigma: Optional[torch.Tensor] = None, + latents: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, @@ -658,7 +664,18 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL latent_width = width // self.vae_spatial_compression_ratio shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if latents is not None and sigma is not None: + if latents.shape != shape: + raise ValueError( + f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input." + ) + latents = latents.to(device=device, dtype=dtype) + sigma = sigma.to(device=device, dtype=dtype) + latents = sigma * noise + (1 - sigma) * latents + else: + latents = noise if len(conditions) > 0: condition_latent_frames_mask = torch.zeros( @@ -766,6 +783,13 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL return latents, conditioning_mask, video_ids, extra_conditioning_num_latents + def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength): + num_steps = min(int(num_inference_steps * strength), num_inference_steps) + start_index = max(num_inference_steps - num_steps, 0) + sigmas = sigmas[start_index:] + timesteps = timesteps[start_index:] + return sigmas, timesteps, num_inference_steps - start_index + @property def guidance_scale(self): return self._guidance_scale @@ -799,6 +823,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL video: List[PipelineImageInput] = None, frame_index: Union[int, List[int]] = 0, strength: Union[float, List[float]] = 1.0, + denoise_strength: float = 1.0, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, @@ -842,6 +867,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL generation. If not provided, one has to pass `conditions`. strength (`float` or `List[float]`, *optional*): The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`. + denoise_strength (`float`, defaults to `1.0`): + The strength of the noise added to the latents for editing. Higher strength leads to more noise added + to the latents, therefore leading to more differences between original video and generated video. This + is useful for video-to-video editing. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -918,8 +947,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - if latents is not None: - raise ValueError("Passing latents is not yet supported.") # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -929,6 +956,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL video=video, frame_index=frame_index, strength=strength, + denoise_strength=denoise_strength, height=height, width=width, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -977,8 +1005,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL strength = [strength] * num_conditions device = self._execution_device + vae_dtype = self.vae.dtype - # 3. Prepare text embeddings + # 3. Prepare text embeddings & conditioning image/video ( prompt_embeds, prompt_attention_mask, @@ -1000,8 +1029,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - vae_dtype = self.vae.dtype - conditioning_tensors = [] is_conditioning_image_or_video = image is not None or video is not None if is_conditioning_image_or_video: @@ -1032,7 +1059,25 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL ) conditioning_tensors.append(condition_tensor) - # 4. Prepare latent variables + # 4. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + sigmas = linear_quadratic_schedule(num_inference_steps) + timesteps = sigmas * 1000 + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + latent_sigma = None + if denoise_strength < 1: + sigmas, timesteps, num_inference_steps = self.get_timesteps( + sigmas, timesteps, num_inference_steps, denoise_strength + ) + latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt) + + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( conditioning_tensors, @@ -1043,6 +1088,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL height=height, width=width, num_frames=num_frames, + sigma=latent_sigma, + latents=latents, generator=generator, device=device, dtype=torch.float32, @@ -1056,21 +1103,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL if self.do_classifier_free_guidance: video_coords = torch.cat([video_coords, video_coords], dim=0) - # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - sigmas = linear_quadratic_schedule(num_inference_steps) - timesteps = sigmas * 1000 - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps=timesteps, - ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1168,7 +1200,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL if not self.vae.config.timestep_conditioning: timestep = None else: - noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) if not isinstance(decode_timestep, list): decode_timestep = [decode_timestep] * batch_size if decode_noise_scale is None: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py new file mode 100644 index 0000000000..a90e52e717 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py @@ -0,0 +1,241 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLLTXVideo +from ...utils import get_logger +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .modeling_latent_upsampler import LTXLatentUpsamplerModel +from .pipeline_output import LTXPipelineOutput + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTXLatentUpsamplePipeline(DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKLLTXVideo, + latent_upsampler: LTXLatentUpsamplerModel, + ) -> None: + super().__init__() + + self.register_modules(vae=vae, latent_upsampler=latent_upsampler) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + return init_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def check_inputs(self, video, height, width, latents): + if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` can be provided.") + if video is None and latents is None: + raise ValueError("One of `video` or `latents` has to be provided.") + + @torch.no_grad() + def __call__( + self, + video: Optional[List[PipelineImageInput]] = None, + height: int = 512, + width: int = 704, + latents: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + self.check_inputs( + video=video, + height=height, + width=width, + latents=latents, + ) + + if video is not None: + # Batched video input is not yet tested/supported. TODO: take a look later + batch_size = 1 + else: + batch_size = latents.shape[0] + device = self._execution_device + + if video is not None: + num_frames = len(video) + if num_frames % self.vae_temporal_compression_ratio != 1: + num_frames = ( + num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1 + ) + video = video[:num_frames] + logger.warning( + f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames." + ) + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=torch.float32) + + latents = self.prepare_latents( + video=video, + batch_size=batch_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.latent_upsampler.dtype) + latents = self.latent_upsampler(latents) + + if output_type == "latent": + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = latents + else: + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index 54ab7d19e3..7c5ac89602 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -248,9 +248,8 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin): pretrained pipeline hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved using [`~FlaxDiffusionPipeline.save_pretrained`]. - dtype (`str` or `jnp.dtype`, *optional*): - Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is - automatically derived from the model's weights. + dtype (`jnp.dtype`, *optional*): + Override the default `jnp.dtype` and load the model under this dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 7cb2a12d3c..4348143b8e 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -573,12 +573,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): saved using [`~DiffusionPipeline.save_pretrained`]. - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file - torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the - dtype is automatically derived from the model's weights. To load submodels with different dtype pass a - `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for - unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default': - torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used. + torch_dtype (`torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. To load submodels with + different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). + Set the default dtype for unspecified components with `default` (for example `{'transformer': + torch.bfloat16, 'default': torch.float16}`). If a component is not specified and no default is set, + `torch.float32` is used. custom_pipeline (`str`, *optional*): diff --git a/src/diffusers/pipelines/visualcloze/__init__.py b/src/diffusers/pipelines/visualcloze/__init__.py new file mode 100644 index 0000000000..ab765a1bba --- /dev/null +++ b/src/diffusers/pipelines/visualcloze/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_visualcloze_combined"] = ["VisualClozePipeline"] + _import_structure["pipeline_visualcloze_generation"] = ["VisualClozeGenerationPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_visualcloze_combined import VisualClozePipeline + from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py new file mode 100644 index 0000000000..0d9114a1a5 --- /dev/null +++ b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py @@ -0,0 +1,444 @@ +# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from PIL import Image +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ..flux.pipeline_flux_fill import FluxFillPipeline as VisualClozeUpsamplingPipeline +from ..flux.pipeline_output import FluxPipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline + + +if is_torch_xla_available(): + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import VisualClozePipeline + >>> from diffusers.utils import load_image + + >>> image_paths = [ + ... # in-context examples + ... [ + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg" + ... ), + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg" + ... ), + ... ], + ... # query with the target image + ... [ + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg" + ... ), + ... None, # No image needed for the target image + ... ], + ... ] + >>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding." + >>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography." + >>> pipe = VisualClozePipeline.from_pretrained( + ... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... task_prompt=task_prompt, + ... content_prompt=content_prompt, + ... image=image_paths, + ... upsampling_width=1344, + ... upsampling_height=768, + ... upsampling_strength=0.4, + ... guidance_scale=30, + ... num_inference_steps=30, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0][0] + >>> image.save("visualcloze.png") + ``` +""" + + +class VisualClozePipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The VisualCloze pipeline for image generation with visual context. Reference: + https://github.com/lzyhha/VisualCloze/tree/main. This pipeline is designed to generate images based on visual + in-context examples. + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + resolution (`int`, *optional*, defaults to 384): + The resolution of each image when concatenating images from the query and in-context examples. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + resolution: int = 384, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + + self.generation_pipe = VisualClozeGenerationPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + resolution=resolution, + ) + self.upsampling_pipe = VisualClozeUpsamplingPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + + def check_inputs( + self, + image, + task_prompt, + content_prompt, + upsampling_height, + upsampling_width, + strength, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if upsampling_height is not None and upsampling_height % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`upsampling_height`has to be divisible by {self.vae_scale_factor * 2} but are {upsampling_height}. Dimensions will be resized accordingly" + ) + if upsampling_width is not None and upsampling_width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`upsampling_width` have to be divisible by {self.vae_scale_factor * 2} but are {upsampling_width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Validate prompt inputs + if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None: + raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ") + + if task_prompt is None and content_prompt is None and prompt_embeds is None: + raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ") + + # Validate prompt types and consistency + if task_prompt is None: + raise ValueError("`task_prompt` is missing.") + + if task_prompt is not None and not isinstance(task_prompt, (str, list)): + raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}") + + if content_prompt is not None and not isinstance(content_prompt, (str, list)): + raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}") + + if isinstance(task_prompt, list) or isinstance(content_prompt, list): + if not isinstance(task_prompt, list) or not isinstance(content_prompt, list): + raise ValueError( + f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, " + f"got {type(task_prompt)} and {type(content_prompt)}" + ) + if len(content_prompt) != len(task_prompt): + raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.") + + for sample in image: + if not isinstance(sample, list) or not isinstance(sample[0], list): + raise ValueError("Each sample in the batch must have a 2D list of images.") + if len({len(row) for row in sample}) != 1: + raise ValueError("Each in-context example and query should contain the same number of images.") + if not any(img is None for img in sample[-1]): + raise ValueError("There are no targets in the query, which should be represented as None.") + for row in sample[:-1]: + if any(img is None for img in row): + raise ValueError("Images are missing in in-context examples.") + + # Validate embeddings + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + # Validate sequence length + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + task_prompt: Union[str, List[str]] = None, + content_prompt: Union[str, List[str]] = None, + image: Optional[torch.FloatTensor] = None, + upsampling_height: Optional[int] = None, + upsampling_width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 30.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + upsampling_strength: float = 1.0, + ): + r""" + Function invoked when calling the VisualCloze pipeline for generation. + + Args: + task_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the task intention. + content_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the content or caption of the target image to be generated. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + upsampling_height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By + default, the image is upsampled by a factor of three, and the base resolution is determined by the + resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is + specified, the other will be automatically set based on the aspect ratio. + upsampling_width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By + default, the image is upsampled by a factor of three, and the base resolution is determined by the + resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is + specified, the other will be automatically set based on the aspect ratio. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 30.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + upsampling_strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image` when upsampling the results. Must be between 0 and + 1. The generated image is used as a starting point and more noise is added the higher the + `upsampling_strength`. The number of denoising steps depends on the amount of noise initially added. + When `upsampling_strength` is 1, added noise is maximum and the denoising process runs for the full + number of iterations specified in `num_inference_steps`. A value of 0 skips the upsampling step and + output the results at the resolution of `self.resolution`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + generation_output = self.generation_pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image, + num_inference_steps=num_inference_steps, + sigmas=sigmas, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + joint_attention_kwargs=joint_attention_kwargs, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + output_type=output_type if upsampling_strength == 0 else "pil", + ) + if upsampling_strength == 0: + if not return_dict: + return (generation_output,) + + return FluxPipelineOutput(images=generation_output) + + # Upsampling the generated images + # 1. Prepare the input images and prompts + if not isinstance(content_prompt, (list)): + content_prompt = [content_prompt] + n_target_per_sample = [] + upsampling_image = [] + upsampling_mask = [] + upsampling_prompt = [] + upsampling_generator = generator if isinstance(generator, (torch.Generator,)) else [] + for i in range(len(generation_output.images)): + n_target_per_sample.append(len(generation_output.images[i])) + for image in generation_output.images[i]: + upsampling_image.append(image) + upsampling_mask.append(Image.new("RGB", image.size, (255, 255, 255))) + upsampling_prompt.append( + content_prompt[i % len(content_prompt)] if content_prompt[i % len(content_prompt)] else "" + ) + if not isinstance(generator, (torch.Generator,)): + upsampling_generator.append(generator[i % len(content_prompt)]) + + # 2. Apply the denosing loop + upsampling_output = self.upsampling_pipe( + prompt=upsampling_prompt, + image=upsampling_image, + mask_image=upsampling_mask, + height=upsampling_height, + width=upsampling_width, + strength=upsampling_strength, + num_inference_steps=num_inference_steps, + sigmas=sigmas, + guidance_scale=guidance_scale, + generator=upsampling_generator, + output_type=output_type, + joint_attention_kwargs=joint_attention_kwargs, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + image = upsampling_output.images + + output = [] + if output_type == "pil": + # Each sample in the batch may have multiple output images. When returning as PIL images, + # these images cannot be concatenated. Therefore, for each sample, + # a list is used to represent all the output images. + output = [] + start = 0 + for n in n_target_per_sample: + output.append(image[start : start + n]) + start += n + else: + output = image + + if not return_dict: + return (output,) + + return FluxPipelineOutput(images=output) diff --git a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py new file mode 100644 index 0000000000..48d03766ae --- /dev/null +++ b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py @@ -0,0 +1,952 @@ +# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..flux.pipeline_flux_fill import calculate_shift, retrieve_latents, retrieve_timesteps +from ..flux.pipeline_output import FluxPipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .visualcloze_utils import VisualClozeProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline + >>> from diffusers.utils import load_image + >>> from PIL import Image + + >>> image_paths = [ + ... # in-context examples + ... [ + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg" + ... ), + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg" + ... ), + ... ], + ... # query with the target image + ... [ + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg" + ... ), + ... None, # No image needed for the target image + ... ], + ... ] + >>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding." + >>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography." + >>> pipe = VisualClozeGenerationPipeline.from_pretrained( + ... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... task_prompt=task_prompt, + ... content_prompt=content_prompt, + ... image=image_paths, + ... guidance_scale=30, + ... num_inference_steps=30, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0][0] + + >>> # optional, upsampling the generated image + >>> pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe) + >>> pipe_upsample.to("cuda") + + >>> mask_image = Image.new("RGB", image.size, (255, 255, 255)) + + >>> image = pipe_upsample( + ... image=image, + ... mask_image=mask_image, + ... prompt=content_prompt, + ... width=1344, + ... height=768, + ... strength=0.4, + ... guidance_scale=30, + ... num_inference_steps=30, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0] + + >>> image.save("visualcloze.png") + ``` +""" + + +class VisualClozeGenerationPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The VisualCloze pipeline for image generation with visual context. Reference: + https://github.com/lzyhha/VisualCloze/tree/main This pipeline is designed to generate images based on visual + in-context examples. + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + resolution (`int`, *optional*, defaults to 384): + The resolution of each image when concatenating images from the query and in-context examples. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + resolution: int = 384, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.resolution = resolution + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VisualClozeProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels, resolution=resolution + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Modified from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + layout_prompt: Union[str, List[str]], + task_prompt: Union[str, List[str]], + content_prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + layout_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the number of in-context examples and the number of images involved in + the task. + task_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the task intention. + content_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the content or caption of the target image to be generated. + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + if isinstance(layout_prompt, str): + layout_prompt = [layout_prompt] + task_prompt = [task_prompt] + content_prompt = [content_prompt] + + def _preprocess(prompt, content=False): + if prompt is not None: + return f"The last image of the last row depicts: {prompt}" if content else prompt + else: + return "" + + prompt = [ + f"{_preprocess(layout_prompt[i])} {_preprocess(task_prompt[i])} {_preprocess(content_prompt[i], content=True)}".strip() + for i in range(len(layout_prompt)) + ] + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + image, + task_prompt, + content_prompt, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Validate prompt inputs + if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None: + raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ") + + if task_prompt is None and content_prompt is None and prompt_embeds is None: + raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ") + + # Validate prompt types and consistency + if task_prompt is None: + raise ValueError("`task_prompt` is missing.") + + if task_prompt is not None and not isinstance(task_prompt, (str, list)): + raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}") + + if content_prompt is not None and not isinstance(content_prompt, (str, list)): + raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}") + + if isinstance(task_prompt, list) or isinstance(content_prompt, list): + if not isinstance(task_prompt, list) or not isinstance(content_prompt, list): + raise ValueError( + f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, " + f"got {type(task_prompt)} and {type(content_prompt)}" + ) + if len(content_prompt) != len(task_prompt): + raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.") + + for sample in image: + if not isinstance(sample, list) or not isinstance(sample[0], list): + raise ValueError("Each sample in the batch must have a 2D list of images.") + if len({len(row) for row in sample}) != 1: + raise ValueError("Each in-context example and query should contain the same number of images.") + if not any(img is None for img in sample[-1]): + raise ValueError("There are no targets in the query, which should be represented as None.") + for row in sample[:-1]: + if any(img is None for img in row): + raise ValueError("Images are missing in in-context examples.") + + # Validate embeddings + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + # Validate sequence length + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(image, vae_scale_factor, device, dtype): + latent_image_ids = [] + + for idx, img in enumerate(image, start=1): + img = img.squeeze(0) + channels, height, width = img.shape + + num_patches_h = height // vae_scale_factor // 2 + num_patches_w = width // vae_scale_factor // 2 + + patch_ids = torch.zeros(num_patches_h, num_patches_w, 3, device=device, dtype=dtype) + patch_ids[..., 0] = idx + patch_ids[..., 1] = torch.arange(num_patches_h, device=device, dtype=dtype)[:, None] + patch_ids[..., 2] = torch.arange(num_patches_w, device=device, dtype=dtype)[None, :] + + patch_ids = patch_ids.reshape(-1, 3) + latent_image_ids.append(patch_ids) + + return torch.cat(latent_image_ids, dim=0) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, sizes, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + start = 0 + unpacked_latents = [] + for i in range(len(sizes)): + cur_size = sizes[i] + height = cur_size[0][0] // vae_scale_factor + width = sum([size[1] for size in cur_size]) // vae_scale_factor + + end = start + (height * width) // 4 + + cur_latents = latents[:, start:end] + cur_latents = cur_latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + cur_latents = cur_latents.permute(0, 3, 1, 4, 2, 5) + cur_latents = cur_latents.reshape(batch_size, channels // (2 * 2), height, width) + + unpacked_latents.append(cur_latents) + + start = end + + return unpacked_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _prepare_latents(self, image, mask, gen, vae_scale_factor, device, dtype): + """Helper function to prepare latents for a single batch.""" + # Concatenate images and masks along width dimension + image = [torch.cat(img, dim=3).to(device=device, dtype=dtype) for img in image] + mask = [torch.cat(m, dim=3).to(device=device, dtype=dtype) for m in mask] + + # Generate latent image IDs + latent_image_ids = self._prepare_latent_image_ids(image, vae_scale_factor, device, dtype) + + # For initial encoding, use actual images + image_latent = [self._encode_vae_image(img, gen) for img in image] + masked_image_latent = [img.clone() for img in image_latent] + + for i in range(len(image_latent)): + # Rearrange latents and masks for patch processing + num_channels_latents, height, width = image_latent[i].shape[1:] + image_latent[i] = self._pack_latents(image_latent[i], 1, num_channels_latents, height, width) + masked_image_latent[i] = self._pack_latents(masked_image_latent[i], 1, num_channels_latents, height, width) + + # Rearrange masks for patch processing + num_channels_latents, height, width = mask[i].shape[1:] + mask[i] = mask[i].view( + 1, + num_channels_latents, + height // vae_scale_factor, + vae_scale_factor, + width // vae_scale_factor, + vae_scale_factor, + ) + mask[i] = mask[i].permute(0, 1, 3, 5, 2, 4) + mask[i] = mask[i].reshape( + 1, + num_channels_latents * (vae_scale_factor**2), + height // vae_scale_factor, + width // vae_scale_factor, + ) + mask[i] = self._pack_latents( + mask[i], + 1, + num_channels_latents * (vae_scale_factor**2), + height // vae_scale_factor, + width // vae_scale_factor, + ) + + # Concatenate along batch dimension + image_latent = torch.cat(image_latent, dim=1) + masked_image_latent = torch.cat(masked_image_latent, dim=1) + mask = torch.cat(mask, dim=1) + + return image_latent, masked_image_latent, mask, latent_image_ids + + def prepare_latents( + self, + input_image, + input_mask, + timestep, + batch_size, + dtype, + device, + generator, + vae_scale_factor, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # Process each batch + masked_image_latents = [] + image_latents = [] + masks = [] + latent_image_ids = [] + + for i in range(len(input_image)): + _image_latent, _masked_image_latent, _mask, _latent_image_ids = self._prepare_latents( + input_image[i], + input_mask[i], + generator if isinstance(generator, torch.Generator) else generator[i], + vae_scale_factor, + device, + dtype, + ) + masked_image_latents.append(_masked_image_latent) + image_latents.append(_image_latent) + masks.append(_mask) + latent_image_ids.append(_latent_image_ids) + + # Concatenate all batches + masked_image_latents = torch.cat(masked_image_latents, dim=0) + image_latents = torch.cat(image_latents, dim=0) + masks = torch.cat(masks, dim=0) + + # Handle batch size expansion + if batch_size > masked_image_latents.shape[0]: + if batch_size % masked_image_latents.shape[0] == 0: + # Expand batches by repeating + additional_image_per_prompt = batch_size // masked_image_latents.shape[0] + masked_image_latents = torch.cat([masked_image_latents] * additional_image_per_prompt, dim=0) + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + masks = torch.cat([masks] * additional_image_per_prompt, dim=0) + else: + raise ValueError( + f"Cannot expand batch size from {masked_image_latents.shape[0]} to {batch_size}. " + "Batch sizes must be multiples of each other." + ) + + # Add noise to latents + noises = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noises).to(dtype=dtype) + + # Combine masked latents with masks + masked_image_latents = torch.cat((masked_image_latents, masks), dim=-1).to(dtype=dtype) + + return latents, masked_image_latents, latent_image_ids[0] + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + task_prompt: Union[str, List[str]] = None, + content_prompt: Union[str, List[str]] = None, + image: Optional[torch.FloatTensor] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 30.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the VisualCloze pipeline for generation. + + Args: + task_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the task intention. + content_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to define the content or caption of the target image to be generated. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 30.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image, + task_prompt, + content_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + processor_output = self.image_processor.preprocess( + task_prompt, content_prompt, image, vae_scale_factor=self.vae_scale_factor + ) + + # 2. Define call parameters + if processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], str): + batch_size = 1 + elif processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], list): + batch_size = len(processor_output["task_prompt"]) + + device = self._execution_device + + # 3. Prepare prompt embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + layout_prompt=processor_output["layout_prompt"], + task_prompt=processor_output["task_prompt"], + content_prompt=processor_output["content_prompt"], + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + # Calculate sequence length and shift factor + image_seq_len = sum( + (size[0] // self.vae_scale_factor // 2) * (size[1] // self.vae_scale_factor // 2) + for sample in processor_output["image_size"][0] + for size in sample + ) + + # Calculate noise schedule parameters + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + # Get timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device) + + # 5. Prepare latent variables + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents, masked_image_latents, latent_image_ids = self.prepare_latents( + processor_output["init_image"], + processor_output["mask"], + latent_timestep, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + vae_scale_factor=self.vae_scale_factor, + ) + + # Calculate warmup steps + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Prepare guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + latent_model_input = torch.cat((latents, masked_image_latents), dim=2) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # Compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # Some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # XLA optimization + if XLA_AVAILABLE: + xm.mark_step() + + # 7. Post-process the image + # Crop the target image + # Since the generated image is a concatenation of the conditional and target regions, + # we need to extract only the target regions based on their positions + image = [] + if output_type == "latent": + image = latents + else: + for b in range(len(latents)): + cur_image_size = processor_output["image_size"][b % batch_size] + cur_target_position = processor_output["target_position"][b % batch_size] + cur_latent = self._unpack_latents(latents[b].unsqueeze(0), cur_image_size, self.vae_scale_factor)[-1] + cur_latent = (cur_latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor + cur_image = self.vae.decode(cur_latent, return_dict=False)[0] + cur_image = self.image_processor.postprocess(cur_image, output_type=output_type)[0] + + start = 0 + cropped = [] + for i, size in enumerate(cur_image_size[-1]): + if cur_target_position[i]: + if output_type == "pil": + cropped.append(cur_image.crop((start, 0, start + size[1], size[0]))) + else: + cropped.append(cur_image[0 : size[0], start : start + size[1]]) + start += size[1] + image.append(cropped) + if output_type != "pil": + image = np.concatenate([arr[None] for sub_image in image for arr in sub_image], axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/visualcloze/visualcloze_utils.py b/src/diffusers/pipelines/visualcloze/visualcloze_utils.py new file mode 100644 index 0000000000..5d221bc1e8 --- /dev/null +++ b/src/diffusers/pipelines/visualcloze/visualcloze_utils.py @@ -0,0 +1,251 @@ +# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from PIL import Image + +from ...image_processor import VaeImageProcessor + + +class VisualClozeProcessor(VaeImageProcessor): + """ + Image processor for the VisualCloze pipeline. + + This processor handles the preprocessing of images for visual cloze tasks, including resizing, normalization, and + mask generation. + + Args: + resolution (int, optional): + Target resolution for processing images. Each image will be resized to this resolution before being + concatenated to avoid the out-of-memory error. Defaults to 384. + *args: Additional arguments passed to [~image_processor.VaeImageProcessor] + **kwargs: Additional keyword arguments passed to [~image_processor.VaeImageProcessor] + """ + + def __init__(self, *args, resolution: int = 384, **kwargs): + super().__init__(*args, **kwargs) + self.resolution = resolution + + def preprocess_image( + self, input_images: List[List[Optional[Image.Image]]], vae_scale_factor: int + ) -> Tuple[List[List[torch.Tensor]], List[List[List[int]]], List[int]]: + """ + Preprocesses input images for the VisualCloze pipeline. + + This function handles the preprocessing of input images by: + 1. Resizing and cropping images to maintain consistent dimensions + 2. Converting images to the Tensor format for the VAE + 3. Normalizing pixel values + 4. Tracking image sizes and positions of target images + + Args: + input_images (List[List[Optional[Image.Image]]]): + A nested list of PIL Images where: + - Outer list represents different samples, including in-context examples and the query + - Inner list contains images for the task + - In the last row, condition images are provided and the target images are placed as None + vae_scale_factor (int): + The scale factor used by the VAE for resizing images + + Returns: + Tuple containing: + - List[List[torch.Tensor]]: Preprocessed images in tensor format + - List[List[List[int]]]: Dimensions of each processed image [height, width] + - List[int]: Target positions indicating which images are to be generated + """ + n_samples, n_task_images = len(input_images), len(input_images[0]) + divisible = 2 * vae_scale_factor + + processed_images: List[List[Image.Image]] = [[] for _ in range(n_samples)] + resize_size: List[Optional[Tuple[int, int]]] = [None for _ in range(n_samples)] + target_position: List[int] = [] + + # Process each sample + for i in range(n_samples): + # Determine size from first non-None image + for j in range(n_task_images): + if input_images[i][j] is not None: + aspect_ratio = input_images[i][j].width / input_images[i][j].height + target_area = self.resolution * self.resolution + new_h = int((target_area / aspect_ratio) ** 0.5) + new_w = int(new_h * aspect_ratio) + + new_w = max(new_w // divisible, 1) * divisible + new_h = max(new_h // divisible, 1) * divisible + resize_size[i] = (new_w, new_h) + break + + # Process all images in the sample + for j in range(n_task_images): + if input_images[i][j] is not None: + target = self._resize_and_crop(input_images[i][j], resize_size[i][0], resize_size[i][1]) + processed_images[i].append(target) + if i == n_samples - 1: + target_position.append(0) + else: + blank = Image.new("RGB", resize_size[i] or (self.resolution, self.resolution), (0, 0, 0)) + processed_images[i].append(blank) + if i == n_samples - 1: + target_position.append(1) + + # Ensure consistent width for multiple target images when there are multiple target images + if len(target_position) > 1 and sum(target_position) > 1: + new_w = resize_size[n_samples - 1][0] or 384 + for i in range(len(processed_images)): + for j in range(len(processed_images[i])): + if processed_images[i][j] is not None: + new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width)) + new_w = int(new_w / 16) * 16 + new_h = int(new_h / 16) * 16 + processed_images[i][j] = self.height(processed_images[i][j], new_h, new_w) + + # Convert to tensors and normalize + image_sizes = [] + for i in range(len(processed_images)): + image_sizes.append([[img.height, img.width] for img in processed_images[i]]) + for j, image in enumerate(processed_images[i]): + image = self.pil_to_numpy(image) + image = self.numpy_to_pt(image) + image = self.normalize(image) + processed_images[i][j] = image + + return processed_images, image_sizes, target_position + + def preprocess_mask( + self, input_images: List[List[Image.Image]], target_position: List[int] + ) -> List[List[torch.Tensor]]: + """ + Generate masks for the VisualCloze pipeline. + + Args: + input_images (List[List[Image.Image]]): + Processed images from preprocess_image + target_position (List[int]): + Binary list marking the positions of target images (1 for target, 0 for condition) + + Returns: + List[List[torch.Tensor]]: + A nested list of mask tensors (1 for target positions, 0 for condition images) + """ + mask = [] + for i, row in enumerate(input_images): + if i == len(input_images) - 1: # Query row + row_masks = [ + torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=m) for m in target_position + ] + else: # In-context examples + row_masks = [ + torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=0) for _ in target_position + ] + mask.append(row_masks) + return mask + + def preprocess_image_upsampling( + self, + input_images: List[List[Image.Image]], + height: int, + width: int, + ) -> Tuple[List[List[Image.Image]], List[List[List[int]]]]: + """Process images for the upsampling stage in the VisualCloze pipeline. + + Args: + input_images: Input image to process + height: Target height + width: Target width + + Returns: + Tuple of processed image and its size + """ + image = self.resize(input_images[0][0], height, width) + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + image = self.normalize(image) + + input_images[0][0] = image + image_sizes = [[[height, width]]] + return input_images, image_sizes + + def preprocess_mask_upsampling(self, input_images: List[List[Image.Image]]) -> List[List[torch.Tensor]]: + return [[torch.ones((1, 1, input_images[0][0].shape[2], input_images[0][0].shape[3]))]] + + def get_layout_prompt(self, size: Tuple[int, int]) -> str: + layout_instruction = ( + f"A grid layout with {size[0]} rows and {size[1]} columns, displaying {size[0] * size[1]} images arranged side by side.", + ) + return layout_instruction + + def preprocess( + self, + task_prompt: Union[str, List[str]], + content_prompt: Union[str, List[str]], + input_images: Optional[List[List[List[Optional[str]]]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + upsampling: bool = False, + vae_scale_factor: int = 16, + ) -> Dict: + """Process visual cloze inputs. + + Args: + task_prompt: Task description(s) + content_prompt: Content description(s) + input_images: List of images or None for the target images + height: Optional target height for upsampling stage + width: Optional target width for upsampling stage + upsampling: Whether this is in the upsampling processing stage + + Returns: + Dictionary containing processed images, masks, prompts and metadata + """ + if isinstance(task_prompt, str): + task_prompt = [task_prompt] + content_prompt = [content_prompt] + input_images = [input_images] + + output = { + "init_image": [], + "mask": [], + "task_prompt": task_prompt if not upsampling else [None for _ in range(len(task_prompt))], + "content_prompt": content_prompt, + "layout_prompt": [], + "target_position": [], + "image_size": [], + } + for i in range(len(task_prompt)): + if upsampling: + layout_prompt = None + else: + layout_prompt = self.get_layout_prompt((len(input_images[i]), len(input_images[i][0]))) + + if upsampling: + cur_processed_images, cur_image_size = self.preprocess_image_upsampling( + input_images[i], height=height, width=width + ) + cur_mask = self.preprocess_mask_upsampling(cur_processed_images) + else: + cur_processed_images, cur_image_size, cur_target_position = self.preprocess_image( + input_images[i], vae_scale_factor=vae_scale_factor + ) + cur_mask = self.preprocess_mask(cur_processed_images, cur_target_position) + + output["target_position"].append(cur_target_position) + + output["image_size"].append(cur_image_size) + output["init_image"].append(cur_processed_images) + output["mask"].append(cur_mask) + output["layout_prompt"].append(layout_prompt) + + return output diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c36c758d87..4ab6091c6d 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1307,6 +1307,21 @@ class LTXImageToVideoPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class LTXLatentUpsamplePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2792,6 +2807,36 @@ class VideoToVideoSDPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class VisualClozeGenerationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VisualClozePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class VQDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_hidream.py b/tests/models/transformers/test_models_transformer_hidream.py new file mode 100644 index 0000000000..fa0fa5123a --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hidream.py @@ -0,0 +1,96 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HiDreamImageTransformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = HiDreamImageTransformer2DModel + main_input_name = "hidden_states" + model_split_percents = [0.8, 0.8, 0.9] + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = width = 32 + embedding_dim_t5, embedding_dim_llama, embedding_dim_pooled = 8, 4, 8 + sequence_length = 8 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length, embedding_dim_t5)).to(torch_device) + encoder_hidden_states_llama3 = torch.randn((batch_size, batch_size, sequence_length, embedding_dim_llama)).to( + torch_device + ) + pooled_embeds = torch.randn((batch_size, embedding_dim_pooled)).to(torch_device) + timesteps = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states_t5": encoder_hidden_states_t5, + "encoder_hidden_states_llama3": encoder_hidden_states_llama3, + "pooled_embeds": pooled_embeds, + "timesteps": timesteps, + } + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 2, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_channels": [8, 4], + "text_emb_dim": 8, + "num_routed_experts": 2, + "num_activated_experts": 2, + "axes_dims_rope": (4, 2, 2), + "max_resolution": (32, 32), + "llama_layers": (0, 1), + "force_inference_output": True, # TODO: as we don't implement MoE loss in training tests. + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skip("HiDreamImageTransformer2DModel uses a dedicated attention processor. This test doesn't apply") + def test_set_attn_processor_for_determinism(self): + pass + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HiDreamImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index 128bf04155..8649ce97a5 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -20,13 +20,13 @@ import torch from diffusers import LTXVideoTransformer3DModel from diffusers.utils.testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin enable_full_determinism() -class LTXTransformerTests(ModelTesterMixin, unittest.TestCase): +class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): model_class = LTXVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 94a5d641a7..24d944bbf9 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -53,7 +53,12 @@ from diffusers.utils.testing_utils import ( torch_device, ) -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ( + LoraHotSwappingForModelTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + UNetTesterMixin, +) if is_peft_available(): @@ -351,7 +356,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): class UNet2DConditionModelTests( - ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase + ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase ): model_class = UNet2DConditionModel main_input_name = "sample" diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py index 1f60c0b421..881946e6a0 100644 --- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -20,7 +20,14 @@ import numpy as np import torch from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel -from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + nightly, + require_torch_accelerator, + skip_mps, + torch_device, +) from ..pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -116,19 +123,19 @@ class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class PipelineIntegrationTests(unittest.TestCase): def setUp(self): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_dance_diffusion(self): device = torch_device diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py index 10a95d6177..6454152b7a 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py @@ -28,13 +28,15 @@ from diffusers import ( VQModel, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_image, load_numpy, nightly, numpy_cosine_similarity_distance, - require_torch_gpu, + require_torch_accelerator, + torch_device, ) from ..test_pipelines_common import PipelineTesterMixin @@ -226,19 +228,19 @@ class KandinskyV22ControlnetPipelineFastTests(PipelineTesterMixin, unittest.Test @nightly -@require_torch_gpu +@require_torch_accelerator class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase): def setUp(self): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_kandinsky_controlnet(self): expected_image = load_numpy( diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py index 58fbbecc05..c99b7b738a 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py @@ -29,13 +29,15 @@ from diffusers import ( VQModel, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_image, load_numpy, nightly, numpy_cosine_similarity_distance, - require_torch_gpu, + require_torch_accelerator, + torch_device, ) from ..test_pipelines_common import PipelineTesterMixin @@ -233,19 +235,19 @@ class KandinskyV22ControlnetImg2ImgPipelineFastTests(PipelineTesterMixin, unitte @nightly -@require_torch_gpu +@require_torch_accelerator class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase): def setUp(self): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_kandinsky_controlnet_img2img(self): expected_image = load_numpy( @@ -309,4 +311,4 @@ class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase): assert image.shape == (512, 512, 3) max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten()) - assert max_diff < 1e-4 + assert max_diff < 5e-4 diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py index e751240e43..245116d5fa 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py @@ -22,10 +22,11 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, load_numpy, nightly, - require_torch_gpu, + require_torch_accelerator, torch_device, ) @@ -136,17 +137,17 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class LDMTextToImagePipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, dtype=torch.float32, seed=0): generator = torch.manual_seed(seed) @@ -177,17 +178,17 @@ class LDMTextToImagePipelineSlowTests(unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class LDMTextToImagePipelineNightlyTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, dtype=torch.float32, seed=0): generator = torch.manual_seed(seed) diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py index 1c3e018a8a..6c425a2e3f 100644 --- a/tests/pipelines/ltx/test_ltx_image2video.py +++ b/tests/pipelines/ltx/test_ltx_image2video.py @@ -109,7 +109,7 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): else: generator = torch.Generator(device=device).manual_seed(seed) - image = torch.randn((1, 3, 32, 32), generator=generator, device=device) + image = torch.rand((1, 3, 32, 32), generator=generator, device=device) inputs = { "image": image, @@ -142,7 +142,7 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): self.assertEqual(generated_video.shape, (9, 3, 32, 32)) expected_video = torch.randn(9, 3, 32, 32) - max_diff = np.abs(generated_video - expected_video).max() + max_diff = torch.amax(torch.abs(generated_video - expected_video)) self.assertLessEqual(max_diff, 1e10) def test_callback_inputs(self): diff --git a/tests/pipelines/ltx/test_ltx_latent_upsample.py b/tests/pipelines/ltx/test_ltx_latent_upsample.py new file mode 100644 index 0000000000..f9ddb12186 --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_latent_upsample.py @@ -0,0 +1,159 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKLLTXVideo, LTXLatentUpsamplePipeline +from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel +from diffusers.utils.testing_utils import enable_full_determinism + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXLatentUpsamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXLatentUpsamplePipeline + params = {"video", "generator"} + batch_params = {"video", "generator"} + required_optional_params = frozenset(["generator", "latents", "return_dict"]) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + decoder_block_out_channels=(8, 8, 8, 8), + layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + latent_upsampler = LTXLatentUpsamplerModel( + in_channels=8, + mid_channels=32, + num_blocks_per_stage=1, + dims=3, + spatial_upsample=True, + temporal_upsample=False, + ) + + components = { + "vae": vae, + "latent_upsampler": latent_upsampler, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + video = torch.randn((5, 3, 32, 32), generator=generator, device=device) + + inputs = { + "video": video, + "generator": generator, + "height": 16, + "width": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (5, 3, 32, 32)) + expected_video = torch.randn(5, 3, 32, 32) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_vae_tiling(self, expected_diff_max: float = 0.25): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + @unittest.skip("Test is not applicable.") + def test_callback_inputs(self): + pass + + @unittest.skip("Test is not applicable.") + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + pass + + @unittest.skip("Test is not applicable.") + def test_inference_batch_consistent(self): + pass + + @unittest.skip("Test is not applicable.") + def test_inference_batch_single_identical(self): + pass diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py index bdd536b6ff..7f553e919c 100644 --- a/tests/pipelines/musicldm/test_musicldm.py +++ b/tests/pipelines/musicldm/test_musicldm.py @@ -39,7 +39,13 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.utils import is_xformers_available -from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + nightly, + require_torch_accelerator, + torch_device, +) from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -408,17 +414,17 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class MusicLDMPipelineNightlyTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): generator = torch.Generator(device=generator_device).manual_seed(seed) diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py index 6cf643fe47..638de7e8cc 100644 --- a/tests/pipelines/shap_e/test_shap_e.py +++ b/tests/pipelines/shap_e/test_shap_e.py @@ -21,7 +21,13 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEPipeline from diffusers.pipelines.shap_e import ShapERenderer -from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import ( + backend_empty_cache, + load_numpy, + nightly, + require_torch_accelerator, + torch_device, +) from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference @@ -222,19 +228,19 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class ShapEPipelineIntegrationTests(unittest.TestCase): def setUp(self): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_shap_e(self): expected_image = load_numpy( diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py index 72eee3e35e..ed0a4d47b6 100644 --- a/tests/pipelines/shap_e/test_shap_e_img2img.py +++ b/tests/pipelines/shap_e/test_shap_e_img2img.py @@ -23,11 +23,12 @@ from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEImg2ImgPipeline from diffusers.pipelines.shap_e import ShapERenderer from diffusers.utils.testing_utils import ( + backend_empty_cache, floats_tensor, load_image, load_numpy, nightly, - require_torch_gpu, + require_torch_accelerator, torch_device, ) @@ -250,19 +251,19 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class ShapEImg2ImgPipelineIntegrationTests(unittest.TestCase): def setUp(self): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_shap_e_img2img(self): input_image = load_image( diff --git a/tests/pipelines/visualcloze/__init__.py b/tests/pipelines/visualcloze/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py b/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py new file mode 100644 index 0000000000..7e2aa25709 --- /dev/null +++ b/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py @@ -0,0 +1,344 @@ +import random +import tempfile +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +import diffusers +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel, VisualClozePipeline +from diffusers.utils import logging +from diffusers.utils.testing_utils import ( + CaptureLogger, + enable_full_determinism, + floats_tensor, + require_accelerator, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class VisualClozePipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = VisualClozePipeline + params = frozenset( + [ + "task_prompt", + "content_prompt", + "upsampling_height", + "upsampling_width", + "guidance_scale", + "prompt_embeds", + "pooled_prompt_embeds", + "upsampling_strength", + ] + ) + batch_params = frozenset(["task_prompt", "content_prompt", "image"]) + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=12, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=6, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[2, 2, 2], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + "resolution": 32, + } + + def get_dummy_inputs(self, device, seed=0): + # Create example images to simulate the input format required by VisualCloze + context_image = [ + Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8)) + for _ in range(2) + ] + query_image = [ + Image.fromarray( + floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8) + ), + None, + ] + + # Create an image list that conforms to the VisualCloze input format + image = [ + context_image, # In-Context example + query_image, # Query image + ] + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.", + "content_prompt": "A beautiful landscape with mountains and a lake", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "upsampling_height": 32, + "upsampling_width": 32, + "max_sequence_length": 77, + "output_type": "np", + "upsampling_strength": 0.4, + } + return inputs + + def test_visualcloze_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["task_prompt"] = "A different task to perform." + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different + assert max_diff > 1e-6 + + def test_visualcloze_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.generation_pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.generation_pipe.vae_scale_factor * 2) + + inputs.update({"upsampling_height": height, "upsampling_width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=1e-3) + + def test_upsampling_strength(self, expected_min_diff=1e-1): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + # Test different upsampling strengths + inputs["upsampling_strength"] = 0.2 + output_no_upsampling = pipe(**inputs).images[0] + + inputs["upsampling_strength"] = 0.8 + output_full_upsampling = pipe(**inputs).images[0] + + # Different upsampling strengths should produce different outputs + max_diff = np.abs(output_no_upsampling - output_full_upsampling).max() + assert max_diff > expected_min_diff + + def test_different_task_prompts(self, expected_min_diff=1e-1): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_original = pipe(**inputs).images[0] + + inputs["task_prompt"] = "A different task description for image generation" + output_different_task = pipe(**inputs).images[0] + + # Different task prompts should produce different outputs + max_diff = np.abs(output_original - output_different_task).max() + assert max_diff > expected_min_diff + + @unittest.skip( + "Test not applicable because the pipeline being tested is a wrapper pipeline. CFG tests should be done on the inner pipelines." + ) + def test_callback_cfg(self): + pass + + def test_save_load_local(self, expected_max_difference=5e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(diffusers.logging.INFO) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + + with CaptureLogger(logger) as cap_logger: + # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware + # This attribute is not serialized in the config of the pipeline + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32) + + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + for name in pipe_loaded.components.keys(): + if name not in pipe_loaded._optional_components: + assert name in str(cap_logger) + + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + if not hasattr(self.pipeline_class, "_optional_components"): + return + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware + # This attribute is not serialized in the config of the pipeline + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator + def test_save_load_float16(self, expected_max_diff=1e-2): + components = self.get_dummy_components() + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware + # This attribute is not serialized in the config of the pipeline + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16, resolution=32) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." + ) diff --git a/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py b/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py new file mode 100644 index 0000000000..0cd714af17 --- /dev/null +++ b/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py @@ -0,0 +1,312 @@ +import random +import tempfile +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxTransformer2DModel, + VisualClozeGenerationPipeline, +) +from diffusers.utils import logging +from diffusers.utils.testing_utils import ( + CaptureLogger, + enable_full_determinism, + floats_tensor, + require_accelerator, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class VisualClozeGenerationPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = VisualClozeGenerationPipeline + params = frozenset( + [ + "task_prompt", + "content_prompt", + "guidance_scale", + "prompt_embeds", + "pooled_prompt_embeds", + ] + ) + batch_params = frozenset(["task_prompt", "content_prompt", "image"]) + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=12, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=6, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[2, 2, 2], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + "resolution": 32, + } + + def get_dummy_inputs(self, device, seed=0): + # Create example images to simulate the input format required by VisualCloze + context_image = [ + Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8)) + for _ in range(2) + ] + query_image = [ + Image.fromarray( + floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8) + ), + None, + ] + + # Create an image list that conforms to the VisualCloze input format + image = [ + context_image, # In-Context example + query_image, # Query image + ] + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.", + "content_prompt": "A beautiful landscape with mountains and a lake", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "max_sequence_length": 77, + "output_type": "np", + } + return inputs + + def test_visualcloze_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["task_prompt"] = "A different task to perform." + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different + assert max_diff > 1e-6 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=1e-3) + + def test_different_task_prompts(self, expected_min_diff=1e-1): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_original = pipe(**inputs).images[0] + + inputs["task_prompt"] = "A different task description for image generation" + output_different_task = pipe(**inputs).images[0] + + # Different task prompts should produce different outputs + max_diff = np.abs(output_original - output_different_task).max() + assert max_diff > expected_min_diff + + def test_save_load_local(self, expected_max_difference=5e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(diffusers.logging.INFO) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + + with CaptureLogger(logger) as cap_logger: + # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware + # This attribute is not serialized in the config of the pipeline + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32) + + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + for name in pipe_loaded.components.keys(): + if name not in pipe_loaded._optional_components: + assert name in str(cap_logger) + + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + if not hasattr(self.pipeline_class, "_optional_components"): + return + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware + # This attribute is not serialized in the config of the pipeline + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator + def test_save_load_float16(self, expected_max_diff=1e-2): + components = self.get_dummy_components() + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware + # This attribute is not serialized in the config of the pipeline + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16, resolution=32) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." + ) + + @unittest.skip("Skipped due to missing layout_prompt. Needs further investigation.") + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=0.0001, rtol=0.0001): + pass