diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml index 1074a5bba6..f928e123aa 100644 --- a/.github/workflows/build_docker_images.yml +++ b/.github/workflows/build_docker_images.yml @@ -25,7 +25,7 @@ jobs: if: github.event_name == 'pull_request' steps: - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Check out code uses: actions/checkout@v6 @@ -101,14 +101,14 @@ jobs: - name: Checkout repository uses: actions/checkout@v6 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ env.REGISTRY }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build and push - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v6 with: no-cache: true context: ./docker/${{ matrix.image-name }} diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml index 3bdfb4ca99..89b502d364 100644 --- a/.github/workflows/pr_modular_tests.yml +++ b/.github/workflows/pr_modular_tests.yml @@ -75,9 +75,27 @@ jobs: if: ${{ failure() }} run: | echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY + check_auto_docs: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.10" + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[quality] + - name: Check auto docs + run: make modular-autodoctrings + - name: Check if failure + if: ${{ failure() }} + run: | + echo "Auto docstring checks failed. Please run `python utils/modular_auto_docstring.py --fix_and_overwrite`." 
>> $GITHUB_STEP_SUMMARY run_fast_tests: - needs: [check_code_quality, check_repository_consistency] + needs: [check_code_quality, check_repository_consistency, check_auto_docs] name: Fast PyTorch Modular Pipeline CPU tests runs-on: diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 64bac65353..87ea38a5bb 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -11,4 +11,4 @@ jobs: - uses: actions/checkout@v6 - name: typos-action - uses: crate-ci/typos@v1.12.4 + uses: crate-ci/typos@v1.42.1 diff --git a/Makefile b/Makefile index 9af2e8b1a5..b90ff82ab2 100644 --- a/Makefile +++ b/Makefile @@ -70,6 +70,10 @@ fix-copies: python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite +# Auto docstrings in modular blocks +modular-autodoctrings: + python utils/modular_auto_docstring.py + # Run tests for the library test: diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile index a700d1db72..134b47215d 100644 --- a/docker/diffusers-pytorch-cuda/Dockerfile +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04 LABEL maintainer="Hugging Face" LABEL repository="diffusers" -ARG PYTHON_VERSION=3.12 +ARG PYTHON_VERSION=3.11 ENV DEBIAN_FRONTEND=noninteractive RUN apt-get -y update \ @@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV} ENV PATH="$VIRTUAL_ENV/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +# Install torch, torchvision, and torchaudio together to ensure compatibility RUN uv pip install --no-cache-dir \ torch \ torchvision \ - torchaudio + torchaudio \ + --index-url https://download.pytorch.org/whl/cu121 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]" diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile index eae7eaf4fa..7714821c82 100644 --- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile +++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile @@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04 LABEL maintainer="Hugging Face" LABEL repository="diffusers" -ARG PYTHON_VERSION=3.12 +ARG PYTHON_VERSION=3.11 ENV DEBIAN_FRONTEND=noninteractive RUN apt-get -y update \ @@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV} ENV PATH="$VIRTUAL_ENV/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +# Install torch, torchvision, and torchaudio together to ensure compatibility RUN uv pip install --no-cache-dir \ torch \ torchvision \ - torchaudio + torchaudio \ + --index-url https://download.pytorch.org/whl/cu121 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]" diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 16f1a5d1ec..6cdd66a0f2 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -478,7 +478,7 @@ class PeftAdapterMixin: Args: adapter_names (`List[str]` or `str`): The names of the adapters to use. - adapter_weights (`Union[List[float], float]`, *optional*): + weights (`Union[List[float], float]`, *optional*): The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the adapters. 
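The docstring fix above aligns the documented argument name with the actual `set_adapters` keyword (`weights`, not `adapter_weights`). As a quick illustration of the documented defaulting behavior — a hedged sketch with a hypothetical helper name, not code from diffusers — a `weights` value of `None`, a scalar, or a list can be normalized to one float per adapter like so:

```python
# Illustrative sketch only: how a `weights` argument like the one documented above can be
# normalized to one float per adapter. The helper name is hypothetical and not part of
# diffusers; it just mirrors the documented behavior ("if None, the weights are set to 1.0").
from typing import List, Optional, Union


def normalize_adapter_weights(
    adapter_names: Union[str, List[str]],
    weights: Optional[Union[float, List[float]]] = None,
) -> List[float]:
    names = [adapter_names] if isinstance(adapter_names, str) else list(adapter_names)
    if weights is None:
        return [1.0] * len(names)             # default: every adapter gets weight 1.0
    if isinstance(weights, (int, float)):
        return [float(weights)] * len(names)  # a scalar is broadcast to all adapters
    if len(weights) != len(names):
        raise ValueError(f"Got {len(weights)} weights for {len(names)} adapters.")
    return [float(w) for w in weights]


print(normalize_adapter_weights(["cinematic", "pixel"], weights=[0.5, 0.5]))  # [0.5, 0.5]
print(normalize_adapter_weights(["cinematic", "pixel"]))                      # [1.0, 1.0]
```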
@@ -495,7 +495,7 @@ class PeftAdapterMixin: "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" ) pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5]) + pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5]) ``` """ if not USE_PEFT_BACKEND: diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index c733bd489d..ab4340fed1 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -152,6 +152,10 @@ SINGLE_FILE_LOADABLE_CLASSES = { "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, "default_subfolder": "transformer", }, + "WanAnimateTransformer3DModel": { + "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, + "default_subfolder": "transformer", + }, "AutoencoderKLWan": { "checkpoint_mapping_fn": convert_wan_vae_to_diffusers, "default_subfolder": "vae", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 26f6c4388d..5e11acb51c 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -136,6 +136,7 @@ CHECKPOINT_KEY_NAMES = { "wan": ["model.diffusion_model.head.modulation", "head.modulation"], "wan_vae": "decoder.middle.0.residual.0.gamma", "wan_vace": "vace_blocks.0.after_proj.bias", + "wan_animate": "motion_encoder.dec.direction.weight", "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias", "cosmos-1.0": [ "net.x_embedder.proj.1.weight", @@ -219,6 +220,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = { "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, + "wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"}, "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"}, "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"}, "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"}, @@ -759,6 +761,9 @@ def infer_diffusers_model_type(checkpoint): elif checkpoint[target_key].shape[0] == 5120: model_type = "wan-vace-14B" + if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint: + model_type = "wan-animate-14B" + elif checkpoint[target_key].shape[0] == 1536: model_type = "wan-t2v-1.3B" elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16: @@ -3154,13 +3159,64 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): + def generate_motion_encoder_mappings(): + mappings = { + "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight", + "motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight", + "motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias", + "motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight", + "motion_encoder.enc.fc": "motion_encoder.motion_network", + } + + for i in range(7): + conv_idx = i + 1 + mappings.update( + { + f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.0.weight": f"motion_encoder.res_blocks.{i}.conv1.weight", + 
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.1.bias": f"motion_encoder.res_blocks.{i}.conv1.act_fn.bias", + f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.1.weight": f"motion_encoder.res_blocks.{i}.conv2.weight", + f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.2.bias": f"motion_encoder.res_blocks.{i}.conv2.act_fn.bias", + f"motion_encoder.enc.net_app.convs.{conv_idx}.skip.1.weight": f"motion_encoder.res_blocks.{i}.conv_skip.weight", + } + ) + + return mappings + + def generate_face_adapter_mappings(): + return { + "face_adapter.fuser_blocks": "face_adapter", + ".k_norm.": ".norm_k.", + ".q_norm.": ".norm_q.", + ".linear1_q.": ".to_q.", + ".linear2.": ".to_out.", + "conv1_local.conv": "conv1_local", + "conv2.conv": "conv2", + "conv3.conv": "conv3", + } + + def split_tensor_handler(key, state_dict, split_pattern, target_keys): + tensor = state_dict.pop(key) + split_idx = tensor.shape[0] // 2 + + new_key_1 = key.replace(split_pattern, target_keys[0]) + new_key_2 = key.replace(split_pattern, target_keys[1]) + + state_dict[new_key_1] = tensor[:split_idx] + state_dict[new_key_2] = tensor[split_idx:] + + def reshape_bias_handler(key, state_dict): + if "motion_encoder.enc.net_app.convs." in key and ".bias" in key: + state_dict[key] = state_dict[key][0, :, 0, 0] + converted_state_dict = {} + # Strip model.diffusion_model prefix keys = list(checkpoint.keys()) for k in keys: if "model.diffusion_model." in k: checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + # Base transformer mappings TRANSFORMER_KEYS_RENAME_DICT = { "time_embedding.0": "condition_embedder.time_embedder.linear_1", "time_embedding.2": "condition_embedder.time_embedder.linear_2", @@ -3182,28 +3238,43 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): "ffn.0": "ffn.net.0.proj", "ffn.2": "ffn.net.2", # Hack to swap the layer names - # The original model calls the norms in following order: norm1, norm3, norm2 - # We convert it to: norm1, norm2, norm3 "norm2": "norm__placeholder", "norm3": "norm2", "norm__placeholder": "norm3", - # For the I2V model + # I2V model "img_emb.proj.0": "condition_embedder.image_embedder.norm1", "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", "img_emb.proj.4": "condition_embedder.image_embedder.norm2", - # For the VACE model + # VACE model "before_proj": "proj_in", "after_proj": "proj_out", } + SPECIAL_KEYS_HANDLERS = {} + if any("face_adapter" in k for k in checkpoint.keys()): + TRANSFORMER_KEYS_RENAME_DICT.update(generate_face_adapter_mappings()) + SPECIAL_KEYS_HANDLERS[".linear1_kv."] = (split_tensor_handler, [".to_k.", ".to_v."]) + + if any("motion_encoder" in k for k in checkpoint.keys()): + TRANSFORMER_KEYS_RENAME_DICT.update(generate_motion_encoder_mappings()) + for key in list(checkpoint.keys()): - new_key = key[:] + reshape_bias_handler(key, checkpoint) + + for key in list(checkpoint.keys()): + new_key = key for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = checkpoint.pop(key) + for key in list(converted_state_dict.keys()): + for pattern, (handler_fn, target_keys) in SPECIAL_KEYS_HANDLERS.items(): + if pattern not in key: + continue + handler_fn(key, converted_state_dict, pattern, target_keys) + break + return converted_state_dict diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index c0b4ad4005..a98cb49114 100644 --- 
a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module): hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor.contiguous()) + # Only use contiguous() during training to avoid DDP gradient stride mismatch warning. + # In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU. + # Issue: https://github.com/huggingface/diffusers/issues/12975 + if self.training: + input_tensor = input_tensor.contiguous() + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 2696f5e787..3d38da1dfc 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionModuleMixin, FeedForward +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed @@ -400,12 +400,14 @@ class LongCatImageTransformer2DModel( PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, + AttentionMixin, ): """ The Transformer model introduced in Longcat-Image. """ _supports_gradient_checkpointing = True + _repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py index 8860f4bca9..538a029cd8 100644 --- a/src/diffusers/models/transformers/transformer_wan_animate.py +++ b/src/diffusers/models/transformers/transformer_wan_animate.py @@ -166,9 +166,11 @@ class MotionConv2d(nn.Module): # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates # set to 1, which should be equivalent to a 2D convolution expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1) + x = x.to(expanded_kernel.dtype) x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels) # Main Conv2D with scaling + x = x.to(self.weight.dtype) x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) # Activation with fused bias, if using @@ -338,8 +340,7 @@ class WanAnimateMotionEncoder(nn.Module): weight = self.motion_synthesis_weight + 1e-8 # Upcast the QR orthogonalization operation to FP32 original_motion_dtype = motion_feat.dtype - motion_feat = motion_feat.to(torch.float32) - weight = weight.to(torch.float32) + motion_feat = motion_feat.to(weight.dtype) Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device) @@ -769,7 +770,7 @@ class WanImageEmbedding(torch.nn.Module): return hidden_states -# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding +# Modified from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding class 
WanTimeTextImageEmbedding(nn.Module): def __init__( self, @@ -803,10 +804,12 @@ class WanTimeTextImageEmbedding(nn.Module): if timestep_seq_len is not None: timestep = timestep.unflatten(0, (-1, timestep_seq_len)) - time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype - if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: - timestep = timestep.to(time_embedder_dtype) - temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + if self.time_embedder.linear_1.weight.dtype.is_floating_point: + time_embedder_dtype = self.time_embedder.linear_1.weight.dtype + else: + time_embedder_dtype = encoder_hidden_states.dtype + + temb = self.time_embedder(timestep.to(time_embedder_dtype)).type_as(encoder_hidden_states) timestep_proj = self.time_proj(self.act_fn(temb)) encoder_hidden_states = self.text_embedder(encoder_hidden_states) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index aa421a5372..f3b12d7161 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -18,6 +18,7 @@ from collections import OrderedDict from dataclasses import dataclass, field, fields from typing import Any, Dict, List, Literal, Optional, Type, Union +import PIL.Image import torch from ..configuration_utils import ConfigMixin, FrozenDict @@ -323,11 +324,192 @@ class ConfigSpec: description: Optional[str] = None -# YiYi Notes: both inputs and intermediate_inputs are InputParam objects -# however some fields are not relevant for intermediate_inputs -# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed -# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs -# -> should we use different class for inputs and intermediate_inputs? +# ====================================================== +# InputParam and OutputParam templates +# ====================================================== + +INPUT_PARAM_TEMPLATES = { + "prompt": { + "type_hint": str, + "required": True, + "description": "The prompt or prompts to guide image generation.", + }, + "negative_prompt": { + "type_hint": str, + "description": "The prompt or prompts not to guide the image generation.", + }, + "max_sequence_length": { + "type_hint": int, + "default": 512, + "description": "Maximum sequence length for prompt encoding.", + }, + "height": { + "type_hint": int, + "description": "The height in pixels of the generated image.", + }, + "width": { + "type_hint": int, + "description": "The width in pixels of the generated image.", + }, + "num_inference_steps": { + "type_hint": int, + "default": 50, + "description": "The number of denoising steps.", + }, + "num_images_per_prompt": { + "type_hint": int, + "default": 1, + "description": "The number of images to generate per prompt.", + }, + "generator": { + "type_hint": torch.Generator, + "description": "Torch generator for deterministic generation.", + }, + "sigmas": { + "type_hint": List[float], + "description": "Custom sigmas for the denoising process.", + }, + "strength": { + "type_hint": float, + "default": 0.9, + "description": "Strength for img2img/inpainting.", + }, + "image": { + "type_hint": Union[PIL.Image.Image, List[PIL.Image.Image]], + "required": True, + "description": "Reference image(s) for denoising. 
Can be a single image or list of images.", + }, + "latents": { + "type_hint": torch.Tensor, + "description": "Pre-generated noisy latents for image generation.", + }, + "timesteps": { + "type_hint": torch.Tensor, + "description": "Timesteps for the denoising process.", + }, + "output_type": { + "type_hint": str, + "default": "pil", + "description": "Output format: 'pil', 'np', 'pt'.", + }, + "attention_kwargs": { + "type_hint": Dict[str, Any], + "description": "Additional kwargs for attention processors.", + }, + "denoiser_input_fields": { + "name": None, + "kwargs_type": "denoiser_input_fields", + "description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + }, + # inpainting + "mask_image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Mask image for inpainting.", + }, + "padding_mask_crop": { + "type_hint": int, + "description": "Padding for mask cropping in inpainting.", + }, + # controlnet + "control_image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Control image for ControlNet conditioning.", + }, + "control_guidance_start": { + "type_hint": float, + "default": 0.0, + "description": "When to start applying ControlNet.", + }, + "control_guidance_end": { + "type_hint": float, + "default": 1.0, + "description": "When to stop applying ControlNet.", + }, + "controlnet_conditioning_scale": { + "type_hint": float, + "default": 1.0, + "description": "Scale for ControlNet conditioning.", + }, + "layers": { + "type_hint": int, + "default": 4, + "description": "Number of layers to extract from the image", + }, + # common intermediate inputs + "prompt_embeds": { + "type_hint": torch.Tensor, + "required": True, + "description": "text embeddings used to guide the image generation. Can be generated from text_encoder step.", + }, + "prompt_embeds_mask": { + "type_hint": torch.Tensor, + "required": True, + "description": "mask for the text embeddings. Can be generated from text_encoder step.", + }, + "negative_prompt_embeds": { + "type_hint": torch.Tensor, + "description": "negative text embeddings used to guide the image generation. Can be generated from text_encoder step.", + }, + "negative_prompt_embeds_mask": { + "type_hint": torch.Tensor, + "description": "mask for the negative text embeddings. Can be generated from text_encoder step.", + }, + "image_latents": { + "type_hint": torch.Tensor, + "required": True, + "description": "image latents used to guide the image generation. Can be generated from vae_encoder step.", + }, + "batch_size": { + "type_hint": int, + "default": 1, + "description": "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + }, + "dtype": { + "type_hint": torch.dtype, + "default": torch.float32, + "description": "The dtype of the model inputs, can be generated in input step.", + }, +} + +OUTPUT_PARAM_TEMPLATES = { + "images": { + "type_hint": List[PIL.Image.Image], + "description": "Generated images.", + }, + "latents": { + "type_hint": torch.Tensor, + "description": "Denoised latents.", + }, + # intermediate outputs + "prompt_embeds": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The prompt embeddings.", + }, + "prompt_embeds_mask": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The encoder attention mask.", + }, + "negative_prompt_embeds": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The negative prompt embeddings.", + }, + "negative_prompt_embeds_mask": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The negative prompt embeddings mask.", + }, + "image_latents": { + "type_hint": torch.Tensor, + "description": "The latent representation of the input image.", + }, +} + + @dataclass class InputParam: """Specification for an input parameter.""" @@ -337,11 +519,31 @@ class InputParam: default: Any = None required: bool = False description: str = "" - kwargs_type: str = None # YiYi Notes: remove this feature (maybe) + kwargs_type: str = None def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + @classmethod + def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam": + """Get template for name if exists, otherwise raise ValueError.""" + if template_name not in INPUT_PARAM_TEMPLATES: + raise ValueError(f"InputParam template for {template_name} not found") + + template_kwargs = INPUT_PARAM_TEMPLATES[template_name].copy() + + # Determine the actual param name: + # 1. From overrides if provided + # 2. From template if present + # 3. Fall back to template_name + name = overrides.pop("name", template_kwargs.pop("name", template_name)) + + if note and "description" in template_kwargs: + template_kwargs["description"] = f"{template_kwargs['description']} ({note})" + + template_kwargs.update(overrides) + return cls(name=name, **template_kwargs) + @dataclass class OutputParam: @@ -350,13 +552,33 @@ class OutputParam: name: str type_hint: Any = None description: str = "" - kwargs_type: str = None # YiYi notes: remove this feature (maybe) + kwargs_type: str = None def __repr__(self): return ( f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" ) + @classmethod + def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam": + """Get template for name if exists, otherwise raise ValueError.""" + if template_name not in OUTPUT_PARAM_TEMPLATES: + raise ValueError(f"OutputParam template for {template_name} not found") + + template_kwargs = OUTPUT_PARAM_TEMPLATES[template_name].copy() + + # Determine the actual param name: + # 1. From overrides if provided + # 2. From template if present + # 3. 
Fall back to template_name + name = overrides.pop("name", template_kwargs.pop("name", template_name)) + + if note and "description" in template_kwargs: + template_kwargs["description"] = f"{template_kwargs['description']} ({note})" + + template_kwargs.update(overrides) + return cls(name=name, **template_kwargs) + def format_inputs_short(inputs): """ @@ -509,10 +731,12 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115): desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description) wrapped_desc = wrap_text(desc, desc_indent, max_line_length) param_str += f"\n{desc_indent}{wrapped_desc}" + else: + param_str += f"\n{desc_indent}TODO: Add description." formatted_params.append(param_str) - return "\n\n".join(formatted_params) + return "\n".join(formatted_params) def format_input_params(input_params, indent_level=4, max_line_length=115): @@ -582,7 +806,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty loading_field_values = [] for field_name in component.loading_fields(): field_value = getattr(component, field_name) - if field_value is not None: + if field_value: loading_field_values.append(f"{field_name}={field_value}") # Add loading field information if available @@ -669,17 +893,17 @@ def make_doc_string( # Add description if description: desc_lines = description.strip().split("\n") - aligned_desc = "\n".join(" " + line for line in desc_lines) + aligned_desc = "\n".join(" " + line.rstrip() for line in desc_lines) output += aligned_desc + "\n\n" # Add components section if provided if expected_components and len(expected_components) > 0: - components_str = format_components(expected_components, indent_level=2) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) output += components_str + "\n\n" # Add configs section if provided if expected_configs and len(expected_configs) > 0: - configs_str = format_configs(expected_configs, indent_level=2) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) output += configs_str + "\n\n" # Add inputs section diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index d9c8cbb01d..80a379da6b 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -118,7 +118,40 @@ def get_timesteps(scheduler, num_inference_steps, strength): # ==================== +# auto_docstring class QwenImagePrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise for the generation process + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. 
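The Inputs block above (its Outputs section continues below) is generated from the new `InputParam.template()` entries rather than hand-written `InputParam(...)` calls. The following self-contained sketch models that template-plus-overrides lookup; the names are simplified stand-ins, not the code from `modular_pipeline_utils.py`:

```python
# Simplified, self-contained model of the InputParam.template() pattern added in this PR.
# It shows how a template entry plus an optional `note` and per-call overrides become one
# InputParam-like spec.
from dataclasses import dataclass
from typing import Any

TEMPLATES = {
    "num_images_per_prompt": {
        "type_hint": int,
        "default": 1,
        "description": "The number of images to generate per prompt.",
    },
}


@dataclass
class Param:
    name: str
    type_hint: Any = None
    default: Any = None
    required: bool = False
    description: str = ""


def from_template(template_name: str, note: str = None, **overrides) -> Param:
    if template_name not in TEMPLATES:
        raise ValueError(f"template for {template_name} not found")
    kwargs = TEMPLATES[template_name].copy()
    name = overrides.pop("name", kwargs.pop("name", template_name))
    if note and "description" in kwargs:
        kwargs["description"] = f"{kwargs['description']} ({note})"
    kwargs.update(overrides)  # per-call overrides win over the template values
    return Param(name=name, **kwargs)


print(from_template("num_images_per_prompt", note="used by the layered variant", default=2))
```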
+ + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + model_name = "qwenimage" @property @@ -134,28 +167,20 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents"), - InputParam(name="height"), - InputParam(name="width"), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="generator"), - InputParam( - name="batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", - ), - InputParam( - name="dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs, can be generated in input step.", - ), + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), OutputParam( name="latents", type_hint=torch.Tensor, @@ -209,7 +234,42 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise (B, layers+1, C, H, W) for the generation process + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + model_name = "qwenimage-layered" @property @@ -225,29 +285,21 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents"), - InputParam(name="height"), - InputParam(name="width"), - InputParam(name="layers", default=4), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="generator"), - InputParam( - name="batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", - ), - InputParam( - name="dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs, can be generated in input step.", - ), + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("layers"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), OutputParam( name="latents", type_hint=torch.Tensor, @@ -301,7 +353,31 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): + """ + Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, + prepare_latents. Both noise and image latents should alreadybe patchified. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + """ + model_name = "qwenimage" @property @@ -323,12 +399,7 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): type_hint=torch.Tensor, description="The initial random noised, can be generated in prepare latent step.", ), - InputParam( - name="image_latents", - required=True, - type_hint=torch.Tensor, - description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.", - ), + InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."), InputParam( name="timesteps", required=True, @@ -345,6 +416,11 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): type_hint=torch.Tensor, description="The initial random noised used for inpainting denoising.", ), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The scaled noisy latents to use for inpainting/image-to-image denoising.", + ), ] @staticmethod @@ -382,7 +458,29 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): + """ + Step that creates mask latents from preprocessed mask_image by interpolating to latent space. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. 
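(The `mask` output of this step follows below.) Conceptually, the step resizes the preprocessed pixel-space mask down to the latent grid before packing it. A rough sketch of that resize — the VAE stride of 8 and `nearest` mode are illustrative assumptions, not values read from this diff:

```python
# Rough sketch of the idea behind the mask-latents step above: a pixel-space mask is
# resized to the latent grid before it is packed. Scale factor and interpolation mode
# are assumptions for illustration, not the pipeline's literal code.
import torch
import torch.nn.functional as F

vae_scale_factor = 8                                     # assumed spatial compression of the VAE
height, width = 1024, 1024
processed_mask_image = torch.ones(1, 1, height, width)   # 1 = region to repaint

mask_latents = F.interpolate(
    processed_mask_image,
    size=(height // vae_scale_factor, width // vae_scale_factor),
    mode="nearest",                                      # keep the mask binary after resizing
)
print(mask_latents.shape)                                # torch.Size([1, 1, 128, 128])
```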
+ + Outputs: + mask (`Tensor`): + The mask to use for the inpainting process. + """ + model_name = "qwenimage" @property @@ -404,9 +502,9 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): type_hint=torch.Tensor, description="The processed mask to use for the inpainting process.", ), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="dtype", required=True), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("dtype"), ] @property @@ -450,7 +548,27 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): # ==================== +# auto_docstring class QwenImageSetTimestepsStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The initial random noised latents for the denoising process. Can be generated in prepare latents step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process + """ + model_name = "qwenimage" @property @@ -466,13 +584,13 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_inference_steps", default=50), - InputParam(name="sigmas"), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), InputParam( name="latents", required=True, type_hint=torch.Tensor, - description="The latents to use for the denoising process, used to calculate the image sequence length.", + description="The initial random noised latents for the denoising process. Can be generated in prepare latents step.", ), ] @@ -516,7 +634,27 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks): + """ + Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + """ + model_name = "qwenimage-layered" @property @@ -532,15 +670,17 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("num_inference_steps", default=50, type_hint=int), - InputParam("sigmas", type_hint=List[float]), - InputParam("image_latents", required=True, type_hint=torch.Tensor), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam.template("image_latents"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="timesteps", type_hint=torch.Tensor), + OutputParam( + name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process." 
+ ), ] @torch.no_grad() @@ -574,7 +714,32 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare + latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The latents to use for the denoising process. Can be generated in prepare latents step. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + num_inference_steps (`int`): + The number of denoising steps to perform at inference time. Updated based on strength. + """ + model_name = "qwenimage" @property @@ -590,15 +755,15 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_inference_steps", default=50), - InputParam(name="sigmas"), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), InputParam( - name="latents", + "latents", required=True, type_hint=torch.Tensor, - description="The latents to use for the denoising process, used to calculate the image sequence length.", + description="The latents to use for the denoising process. Can be generated in prepare latents step.", ), - InputParam(name="strength", default=0.9), + InputParam.template("strength", default=0.9), ] @property @@ -607,7 +772,12 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): OutputParam( name="timesteps", type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + description="The timesteps to use for the denoising process.", + ), + OutputParam( + name="num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time. Updated based on strength.", ), ] @@ -654,7 +824,29 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): ## RoPE inputs for denoiser +# auto_docstring class QwenImageRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. 
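(The `img_shapes` output of this RoPE step follows below.) The quantity it carries is essentially the per-image latent grid size used for rotary position embeddings; the arithmetic below sketches it under assumed factors (VAE stride 8, patch size 2), purely as an illustration:

```python
# Back-of-the-envelope version of what the RoPE-inputs step computes: the latent grid
# shape per image, used for rotary position embeddings. The factors are assumptions.
vae_scale_factor = 8
patch_size = 2
batch_size = 2
height, width = 1024, 1024

latent_h = height // vae_scale_factor // patch_size
latent_w = width // vae_scale_factor // patch_size
img_shapes = [[(1, latent_h, latent_w)] for _ in range(batch_size)]
print(img_shapes)   # [[(1, 64, 64)], [(1, 64, 64)]]
```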
+ + Outputs: + img_shapes (`List`): + The shapes of the images latents, used for RoPE calculation + """ + model_name = "qwenimage" @property @@ -666,11 +858,11 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -702,7 +894,34 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after + prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`int`): + The height of the reference image. Can be generated in input step. + image_width (`int`): + The width of the reference image. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`List`): + The shapes of the images latents, used for RoPE calculation + """ + model_name = "qwenimage" @property @@ -712,13 +931,23 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="image_height", required=True), - InputParam(name="image_width", required=True), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=int, + description="The height of the reference image. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=int, + description="The width of the reference image. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -756,7 +985,39 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus. + Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed + after prepare_latents step. 
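(The Inputs/Outputs for this Edit Plus step continue below.) The sentence above is the key difference from the Edit variant: `image_height`/`image_width` arrive as lists, one entry per reference image. A small sketch of how per-reference latent grids could be assembled alongside the target image's grid — the stride and patch factors are assumptions, not values taken from the block:

```python
# Sketch of the "lists of image_height/image_width" handling described above for Edit Plus:
# each reference image contributes its own latent grid entry alongside the target image's.
vae_scale_factor = 8   # assumed VAE stride
patch_size = 2         # assumed patch size


def to_latent_grid(h: int, w: int):
    return (1, h // vae_scale_factor // patch_size, w // vae_scale_factor // patch_size)


height, width = 1024, 1024    # target image
image_height = [768, 512]     # per-reference heights
image_width = [768, 512]      # per-reference widths

img_shapes = [[to_latent_grid(height, width)]
              + [to_latent_grid(h, w) for h, w in zip(image_height, image_width)]]
print(img_shapes)   # [[(1, 64, 64), (1, 48, 48), (1, 32, 32)]]
```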
+ + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`List`): + The heights of the reference images. Can be generated in input step. + image_width (`List`): + The widths of the reference images. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`List`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`List`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`List`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + """ + model_name = "qwenimage-edit-plus" @property @@ -770,13 +1031,23 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="image_height", required=True, type_hint=List[int]), - InputParam(name="image_width", required=True, type_hint=List[int]), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=List[int], + description="The heights of the reference images. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=List[int], + description="The widths of the reference images. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -832,7 +1103,37 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. 
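(The Outputs section, including `txt_seq_lens`, follows below.) Those sequence lengths are conventionally just the number of non-padded tokens per sample in the prompt mask; the minimal sketch below shows that convention and is not the block's literal implementation:

```python
# Minimal illustration of how text sequence lengths can be read off the prompt mask that
# the RoPE steps above take as input: the per-sample count of non-padded tokens.
import torch

prompt_embeds_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                                   [1, 1, 0, 0, 0, 0]])
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
print(txt_seq_lens)   # [4, 2]
```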
+ + Outputs: + img_shapes (`List`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`List`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`List`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + additional_t_cond (`Tensor`): + The additional t cond, used for RoPE calculation + """ + model_name = "qwenimage-layered" @property @@ -844,12 +1145,12 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="layers", required=True), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam.template("layers"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -914,7 +1215,34 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks): ## ControlNet inputs for denoiser + + +# auto_docstring class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): + """ + step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step. + + Components: + controlnet (`QwenImageControlNetModel`) + + Inputs: + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + controlnet_keep (`List`): + The controlnet keep values + """ + model_name = "qwenimage" @property @@ -930,12 +1258,17 @@ class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("control_image_latents", required=True), + InputParam.template("control_guidance_start"), + InputParam.template("control_guidance_end"), + InputParam.template("controlnet_conditioning_scale"), InputParam( - "timesteps", + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam( + name="timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py index 24a88ebfca..1adbf6bdd3 100644 --- a/src/diffusers/modular_pipelines/qwenimage/decoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
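Before the decoder changes continue below, one note on the ControlNet step documented above: its `controlnet_keep` output encodes, per denoising step, whether ControlNet guidance is active. The sketch below mirrors the keep-fraction convention used by diffusers ControlNet pipelines and should be read as an illustration, not the modular block's exact code:

```python
# Sketch of the schedule behind the `controlnet_keep` output documented above: 1.0 while
# the current step falls inside [control_guidance_start, control_guidance_end], else 0.0.
def controlnet_keep_schedule(num_steps: int, start: float = 0.0, end: float = 1.0):
    keeps = []
    for i in range(num_steps):
        outside = (i / num_steps) < start or ((i + 1) / num_steps) > end
        keeps.append(0.0 if outside else 1.0)
    return keeps


print(controlnet_keep_schedule(10, start=0.2, end=0.8))
# [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```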
-from typing import List, Union +from typing import Any, Dict, List -import numpy as np -import PIL import torch from ...configuration_utils import FrozenDict @@ -31,7 +29,30 @@ logger = logging.get_logger(__name__) # after denoising loop (unpack latents) + + +# auto_docstring class QwenImageAfterDenoiseStep(ModularPipelineBlocks): + """ + Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, + channels, 1, height, width) + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + latents (`Tensor`): + The latents to decode, can be generated in the denoise step. + + Outputs: + latents (`Tensor`): + The denoisedlatents unpacked to B, C, 1, H, W + """ + model_name = "qwenimage" @property @@ -49,13 +70,21 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="height", required=True), - InputParam(name="width", required=True), + InputParam.template("height", required=True), + InputParam.template("width", required=True), InputParam( name="latents", required=True, type_hint=torch.Tensor, - description="The latents to decode, can be generated in the denoise step", + description="The latents to decode, can be generated in the denoise step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="latents", type_hint=torch.Tensor, description="The denoisedlatents unpacked to B, C, 1, H, W" ), ] @@ -72,7 +101,29 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + + Outputs: + latents (`Tensor`): + Denoised latents. 
(unpacked to B, C, layers+1, H, W) + """ + model_name = "qwenimage-layered" @property @@ -88,10 +139,21 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("height", required=True, type_hint=int), - InputParam("width", required=True, type_hint=int), - InputParam("layers", required=True, type_hint=int), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("layers"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("latents", note="unpacked to B, C, layers+1, H, W"), ] @torch.no_grad() @@ -112,7 +174,26 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks): # decode step + + +# auto_docstring class QwenImageDecoderStep(ModularPipelineBlocks): + """ + Step that decodes the latents to images + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage" @property @@ -134,19 +215,13 @@ class QwenImageDecoderStep(ModularPipelineBlocks): name="latents", required=True, type_hint=torch.Tensor, - description="The latents to decode, can be generated in the denoise step", + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", ), ] @property - def intermediate_outputs(self) -> List[str]: - return [ - OutputParam( - "images", - type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], - description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", - ) - ] + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images", note="tensor output of the vae decoder.")] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -176,7 +251,26 @@ class QwenImageDecoderStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageLayeredDecoderStep(ModularPipelineBlocks): + """ + Decode unpacked latents (B, C, layers+1, H, W) into layer images. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. 
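(The docstring closes just below.) The decoder blocks in this hunk assume latents have already been unpacked from the packed `(B, seq, C*p*p)` layout back to `(B, C, 1, H, W)` (or `(B, C, layers+1, H, W)` for the layered variant) before the VAE decode. A toy version of that fold, with patch size 2 and 16 latent channels as assumed values rather than figures read from this diff:

```python
# Toy illustration of the unpack step the decoder blocks above rely on: packed latents of
# shape (B, seq, C * p * p) are folded back into (B, C, 1, H, W) before the VAE decode.
import torch

batch, channels, patch = 1, 16, 2
latent_h, latent_w = 128, 128                        # latent grid after VAE encoding
seq_len = (latent_h // patch) * (latent_w // patch)

packed = torch.randn(batch, seq_len, channels * patch * patch)

unpacked = packed.view(batch, latent_h // patch, latent_w // patch, channels, patch, patch)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)        # (B, C, H/p, p, W/p, p)
unpacked = unpacked.reshape(batch, channels, 1, latent_h, latent_w)
print(unpacked.shape)                                # torch.Size([1, 16, 1, 128, 128])
```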
+ """ + model_name = "qwenimage-layered" @property @@ -198,14 +292,19 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("output_type", default="pil", type_hint=str), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", + ), + InputParam.template("output_type"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]), + OutputParam.template("images"), ] @torch.no_grad() @@ -251,7 +350,27 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks): # postprocess the decoded images + + +# auto_docstring class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" @property @@ -272,15 +391,19 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("images", required=True, description="the generated image from decoders step"), InputParam( - name="output_type", - default="pil", - type_hint=str, - description="The type of the output images, can be 'pil', 'np', 'pt'", + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", ), + InputParam.template("output_type"), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images")] + @staticmethod def check_inputs(output_type): if output_type not in ["pil", "np", "pt"]: @@ -301,7 +424,28 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image, optional apply the mask overally to the original image.. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" @property @@ -322,16 +466,24 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("images", required=True, description="the generated image from decoders step"), InputParam( - name="output_type", - default="pil", - type_hint=str, - description="The type of the output images, can be 'pil', 'np', 'pt'", + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", + ), + InputParam.template("output_type"), + InputParam( + name="mask_overlay_kwargs", + type_hint=Dict[str, Any], + description="The kwargs for the postprocess step to apply the mask overlay. 
generated in InpaintProcessImagesInputStep.", ), - InputParam("mask_overlay_kwargs"), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images")] + @staticmethod def check_inputs(output_type, mask_overlay_kwargs): if output_type not in ["pil", "np", "pt"]: diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index d6bcb4a94f..8579c9843a 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -50,7 +50,7 @@ class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks): def inputs(self) -> List[InputParam]: return [ InputParam( - "latents", + name="latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", @@ -80,17 +80,12 @@ class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks): def inputs(self) -> List[InputParam]: return [ InputParam( - "latents", + name="latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.", - ), + InputParam.template("image_latents"), ] @torch.no_grad() @@ -134,29 +129,12 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks): type_hint=torch.Tensor, description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", ), + InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."), InputParam( - "controlnet_conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", - ), - InputParam( - "controlnet_keep", + name="controlnet_keep", required=True, type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description=( - "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds." - ), + description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step.", ), ] @@ -217,28 +195,13 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("attention_kwargs"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The latents to use for the denoising process. Can be generated in prepare_latents step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. 
prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), InputParam( "img_shapes", required=True, type_hint=List[Tuple[int, int]], - description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.", + description="The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.", ), ] @@ -317,23 +280,8 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("attention_kwargs"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The latents to use for the denoising process. Can be generated in prepare_latents step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), InputParam( "img_shapes", required=True, @@ -415,7 +363,7 @@ class QwenImageLoopAfterDenoiser(ModularPipelineBlocks): @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."), + OutputParam.template("latents"), ] @torch.no_grad() @@ -456,24 +404,19 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks): type_hint=torch.Tensor, description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.", ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.", - ), + InputParam.template("image_latents"), InputParam( "initial_noise", required=True, type_hint=torch.Tensor, description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.", ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", - ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("latents"), ] @torch.no_grad() @@ -515,17 +458,12 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): def loop_inputs(self) -> List[InputParam]: return [ InputParam( - "timesteps", + name="timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), + InputParam.template("num_inference_steps", required=True), ] @torch.no_grad() @@ -557,7 +495,42 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): # Qwen Image (text2image, image2image) + + +# auto_docstring class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. 
+ Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2image and image2image tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ @@ -570,8 +543,8 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): @property def description(self) -> str: return ( - "Denoise step that iteratively denoise the latents. \n" - "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "Denoise step that iteratively denoise the latents.\n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n" "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" " - `QwenImageLoopBeforeDenoiser`\n" " - `QwenImageLoopDenoiser`\n" @@ -581,7 +554,47 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): # Qwen Image (inpainting) +# auto_docstring class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. 
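The inpaint denoise variants take `mask`, `image_latents`, and `initial_noise` precisely because, after each scheduler step, the unmasked region has to be pinned back to the original image. The sketch below shows a common form of that blending; the re-noising formula is an assumption made for illustration (flow-match style), not the body of `QwenImageLoopAfterDenoiserInpaint`, which is not shown in this hunk.

import torch

def blend_inpaint_latents(
    latents: torch.Tensor,        # output of the scheduler step at this iteration
    image_latents: torch.Tensor,  # clean latents of the original image
    initial_noise: torch.Tensor,  # the noise sampling started from
    mask: torch.Tensor,           # 1.0 where we repaint, 0.0 where we keep the original
    sigma: float,                 # remaining noise level at this step (assumed convention)
) -> torch.Tensor:
    # re-noise the original image latents to the current noise level (illustrative formula)
    noised_image_latents = (1.0 - sigma) * image_latents + sigma * initial_noise
    # keep original content outside the mask, freshly denoised content inside it
    return (1.0 - mask) * noised_image_latents + mask * latents

shape = (1, 16, 1, 8, 8)
out = blend_inpaint_latents(
    torch.zeros(shape), torch.ones(shape), torch.randn(shape), torch.ones(shape), sigma=0.5
)
print(out.abs().max())  # an all-ones mask keeps only the denoised latents -> tensor(0.)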
+ + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, @@ -606,7 +619,47 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): # Qwen Image (text2image, image2image) with controlnet +# auto_docstring class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2img/img2img tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`List`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, @@ -631,7 +684,54 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): # Qwen Image (inpainting) with controlnet +# auto_docstring class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. 
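`controlnet_keep` and `controlnet_conditioning_scale` appear side by side in these controlnet denoise inputs. A plausible reading (hedged, since the loop body is outside this hunk) is that the keep schedule gates the global scale per step, as sketched below; the particular start/end schedule is an assumption for illustration.

# Hedged sketch: per-step ControlNet scale from a global scale and a keep schedule.
def make_controlnet_keep(num_steps: int, start: float = 0.0, end: float = 1.0) -> list[float]:
    # 1.0 while the normalized step index lies in [start, end), else 0.0 (assumed schedule)
    return [1.0 if start <= i / num_steps < end else 0.0 for i in range(num_steps)]

def effective_scale(step_index: int, conditioning_scale: float, keep: list[float]) -> float:
    return conditioning_scale * keep[step_index]

keep = make_controlnet_keep(num_steps=8, start=0.0, end=0.5)
print([effective_scale(i, 1.0, keep) for i in range(8)])
# -> [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]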
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`List`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, @@ -664,7 +764,42 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): # Qwen Image Edit (image2image) +# auto_docstring class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditLoopBeforeDenoiser, @@ -687,7 +822,47 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper): # Qwen Image Edit (inpainting) +# auto_docstring class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage Edit. 
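Much of this diff swaps hand-written `InputParam(...)`/`OutputParam(...)` definitions for `InputParam.template(...)`/`OutputParam.template(...)` lookups. The standalone sketch below illustrates the general pattern such a template registry enables (one shared table of parameter specs, overridden per block); the field names and the `note=` handling are assumptions, and this is not the actual `modular_pipeline_utils` code.

from dataclasses import dataclass, replace
from typing import Any, Dict, Optional

@dataclass(frozen=True)
class ParamSpec:
    name: str
    required: bool = False
    default: Any = None
    description: str = ""

# one shared table instead of copy-pasted descriptions in every block
_TEMPLATES: Dict[str, ParamSpec] = {
    "num_inference_steps": ParamSpec(
        "num_inference_steps", default=50, description="The number of denoising steps."
    ),
    "latents": ParamSpec("latents", description="Denoised latents."),
}

def template(name: str, *, note: Optional[str] = None, **overrides: Any) -> ParamSpec:
    """Look up a shared spec and apply per-block overrides (e.g. required=True)."""
    spec = replace(_TEMPLATES[name], **overrides)
    if note:  # append a block-specific note, mirroring the `note=` argument seen above
        spec = replace(spec, description=f"{spec.description} ({note})")
    return spec

print(template("num_inference_steps", required=True))
print(template("latents", note="unpacked to B, C, layers+1, H, W"))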
+ + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditLoopBeforeDenoiser, @@ -712,7 +887,42 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): # Qwen Image Layered (image2image) +# auto_docstring class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Layered. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. 
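Each `*DenoiseStep` in this file is only a loop wrapper plus an ordered `block_classes` list that runs once per timestep. The toy below mirrors that control flow with stand-in blocks; it is an illustration of the structure, not the real `LoopSequentialPipelineBlocks` machinery.

import torch

class ToyLoopBlock:
    def __call__(self, state: dict, t: torch.Tensor) -> dict:
        raise NotImplementedError

class PrepareInput(ToyLoopBlock):
    def __call__(self, state, t):
        state["model_input"] = state["latents"]  # e.g. concatenate image latents here
        return state

class PredictAndStep(ToyLoopBlock):
    def __call__(self, state, t):
        noise_pred = torch.zeros_like(state["model_input"])   # stand-in for the transformer
        state["latents"] = state["latents"] - 0.1 * noise_pred  # stand-in scheduler step
        return state

class ToyDenoiseLoop:
    """Runs its sub-blocks sequentially at every timestep, like the wrapper above."""
    sub_blocks = [PrepareInput(), PredictAndStep()]

    def __call__(self, state: dict, timesteps: torch.Tensor) -> dict:
        for t in timesteps:
            for block in self.sub_blocks:
                state = block(state, t)
        return state

out = ToyDenoiseLoop()({"latents": torch.randn(1, 16, 1, 8, 8)}, torch.linspace(1.0, 0.0, 4))
print(out["latents"].shape)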
+ """ + model_name = "qwenimage-layered" block_classes = [ QwenImageEditLoopBeforeDenoiser, diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 4b66dd32e5..5e1821cca5 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -30,7 +30,7 @@ from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions from ...utils import logging from ...utils.torch_utils import unwrap_module from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import QwenImageModularPipeline from .prompt_templates import ( QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, @@ -259,33 +259,47 @@ def encode_vae_image( # ==================== # 1. RESIZE # ==================== +# In QwenImage pipelines, resize is a separate step because the resized image is used in VL encoding and vae encoder blocks: +# +# image (PIL.Image.Image) +# │ +# ▼ +# resized_image ([PIL.Image.Image]) +# │ +# ├──► text_encoder ──► prompt_embeds, prompt_embeds_mask +# │ (VL encoding needs the resized image for vision-language fusion) +# │ +# └──► image_processor ──► processed_image (torch.Tensor, pixel space) +# │ +# ▼ +# vae_encoder ──► image_latents (torch.Tensor, latent space) +# +# In most of our other pipelines, resizing is done as part of the image preprocessing step. +# ==================== + + +# auto_docstring class QwenImageEditResizeStep(ModularPipelineBlocks): + """ + Image Resize step that resize the image to target area while maintaining the aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`List`): + The resized images + """ + model_name = "qwenimage-edit" - def __init__( - self, - input_name: str = "image", - output_name: str = "resized_image", - ): - """Create a configurable step for resizing images to the target area while maintaining the aspect ratio. - - Args: - input_name (str, optional): Name of the image field to read from the - pipeline state. Defaults to "image". - output_name (str, optional): Name of the resized image field to write - back to the pipeline state. Defaults to "resized_image". - """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError( - f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" - ) - self._image_input_name = input_name - self._resized_image_output_name = output_name - super().__init__() - @property def description(self) -> str: - return f"Image Resize step that resize the {self._image_input_name} to target area while maintaining the aspect ratio." + return "Image Resize step that resize the image to target area while maintaining the aspect ratio." 
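The resize steps keep the aspect ratio while hitting a target pixel area (`calculate_dimensions`, imported from the Edit pipeline, does this for them). A hedged re-derivation of the idea follows; the snap-to-a-multiple rounding is an assumption for illustration and may differ from what `calculate_dimensions` actually returns.

import math

def area_preserving_size(target_area: int, aspect_ratio: float, multiple: int = 32) -> tuple[int, int]:
    """Pick (width, height) with width/height ~= aspect_ratio and width*height ~= target_area."""
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    # snap to a multiple so downstream patching/VAE strides divide evenly (assumed convention)
    width = max(multiple, round(width / multiple) * multiple)
    height = max(multiple, round(height / multiple) * multiple)
    return int(width), int(height)

print(area_preserving_size(1024 * 1024, 16 / 9))  # wide target area, roughly one megapixel
print(area_preserving_size(640 * 640, 1.0))       # square "resolution=640" target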
@property def expected_components(self) -> List[ComponentSpec]: @@ -300,17 +314,15 @@ class QwenImageEditResizeStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: - return [ - InputParam( - name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize" - ), - ] + return [InputParam.template("image")] @property def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( - name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" + name="resized_image", + type_hint=List[PIL.Image.Image], + description="The resized images", ), ] @@ -318,7 +330,7 @@ class QwenImageEditResizeStep(ModularPipelineBlocks): def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - images = getattr(block_state, self._image_input_name) + images = block_state.image if not is_valid_image_imagelist(images): raise ValueError(f"Images must be image or list of images but are {type(images)}") @@ -334,38 +346,36 @@ class QwenImageEditResizeStep(ModularPipelineBlocks): for image in images ] - setattr(block_state, self._resized_image_output_name, resized_images) + block_state.resized_image = resized_images self.set_block_state(state, block_state) return components, state +# auto_docstring class QwenImageLayeredResizeStep(ModularPipelineBlocks): + """ + Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while + maintaining the aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + + Outputs: + resized_image (`List`): + The resized images + """ + model_name = "qwenimage-layered" - def __init__( - self, - input_name: str = "image", - output_name: str = "resized_image", - ): - """Create a configurable step for resizing images to the target area while maintaining the aspect ratio. - - Args: - input_name (str, optional): Name of the image field to read from the - pipeline state. Defaults to "image". - output_name (str, optional): Name of the resized image field to write - back to the pipeline state. Defaults to "resized_image". - """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError( - f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" - ) - self._image_input_name = input_name - self._resized_image_output_name = output_name - super().__init__() - @property def description(self) -> str: - return f"Image Resize step that resize the {self._image_input_name} to target area while maintaining the aspect ratio." + return "Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while maintaining the aspect ratio." 
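Every block in this file follows the same contract: declare `inputs` and `intermediate_outputs`, then read and write block state inside `__call__`. The stripped-down imitation below (plain toy classes, not the real `ModularPipelineBlocks` base or its `get_block_state`/`set_block_state` helpers) shows that shape in isolation.

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ToyState:
    values: Dict[str, Any] = field(default_factory=dict)

class ToyResizeBlock:
    """Reads `image`, writes `resized_image`, mimicking the declare-then-run pattern."""

    @property
    def inputs(self) -> List[str]:
        return ["image"]

    @property
    def intermediate_outputs(self) -> List[str]:
        return ["resized_image"]

    def __call__(self, state: ToyState) -> ToyState:
        image = state.values["image"]            # get_block_state(...) analogue
        state.values["resized_image"] = [image]  # pretend we resized; write the result back
        return state                             # set_block_state(...) analogue

state = ToyResizeBlock()(ToyState({"image": "a PIL image would go here"}))
print(state.values["resized_image"])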
@property def expected_components(self) -> List[ComponentSpec]: @@ -381,9 +391,7 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam( - name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize" - ), + InputParam.template("image"), InputParam( name="resolution", default=640, @@ -396,8 +404,10 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks): def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( - name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" - ), + name="resized_image", + type_hint=List[PIL.Image.Image], + description="The resized images", + ) ] @staticmethod @@ -411,7 +421,7 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks): self.check_inputs(resolution=block_state.resolution) - images = getattr(block_state, self._image_input_name) + images = block_state.image if not is_valid_image_imagelist(images): raise ValueError(f"Images must be image or list of images but are {type(images)}") @@ -428,45 +438,40 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks): for image in images ] - setattr(block_state, self._resized_image_output_name, resized_images) + block_state.resized_image = resized_images self.set_block_state(state, block_state) return components, state +# auto_docstring class QwenImageEditPlusResizeStep(ModularPipelineBlocks): - """Resize each image independently based on its own aspect ratio. For QwenImage Edit Plus.""" + """ + Resize images for QwenImage Edit Plus pipeline. + Produces two outputs: resized_image (1024x1024) for VAE encoding, resized_cond_image (384x384) for VL text + encoding. Each image is resized independently based on its own aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + """ model_name = "qwenimage-edit-plus" - def __init__( - self, - input_name: str = "image", - output_name: str = "resized_image", - target_area: int = 1024 * 1024, - ): - """Create a step for resizing images to a target area. - - Each image is resized independently based on its own aspect ratio. This is suitable for Edit Plus where - multiple reference images can have different dimensions. - - Args: - input_name (str, optional): Name of the image field to read. Defaults to "image". - output_name (str, optional): Name of the resized image field to write. Defaults to "resized_image". - target_area (int, optional): Target area in pixels. Defaults to 1024*1024. 
- """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError( - f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" - ) - self._image_input_name = input_name - self._resized_image_output_name = output_name - self._target_area = target_area - super().__init__() - @property def description(self) -> str: return ( - f"Image Resize step that resizes {self._image_input_name} to target area {self._target_area}.\n" + "Resize images for QwenImage Edit Plus pipeline.\n" + "Produces two outputs: resized_image (1024x1024) for VAE encoding, " + "resized_cond_image (384x384) for VL text encoding.\n" "Each image is resized independently based on its own aspect ratio." ) @@ -483,20 +488,21 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: - return [ - InputParam( - name=self._image_input_name, - required=True, - type_hint=torch.Tensor, - description="The image(s) to resize", - ), - ] + # image + return [InputParam.template("image")] @property def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( - name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" + name="resized_image", + type_hint=List[PIL.Image.Image], + description="Images resized to 1024x1024 target area for VAE encoding", + ), + OutputParam( + name="resized_cond_image", + type_hint=List[PIL.Image.Image], + description="Images resized to 384x384 target area for VL text encoding", ), ] @@ -504,7 +510,7 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks): def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - images = getattr(block_state, self._image_input_name) + images = block_state.image if not is_valid_image_imagelist(images): raise ValueError(f"Images must be image or list of images but are {type(images)}") @@ -514,16 +520,22 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks): # Resize each image independently based on its own aspect ratio resized_images = [] + resized_cond_images = [] for image in images: image_width, image_height = image.size - calculated_width, calculated_height, _ = calculate_dimensions( - self._target_area, image_width / image_height - ) - resized_images.append( - components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + + # For VAE encoder (1024x1024 target area) + vae_width, vae_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + resized_images.append(components.image_resize_processor.resize(image, height=vae_height, width=vae_width)) + + # For VL text encoder (384x384 target area) + vl_width, vl_height, _ = calculate_dimensions(384 * 384, image_width / image_height) + resized_cond_images.append( + components.image_resize_processor.resize(image, height=vl_height, width=vl_width) ) - setattr(block_state, self._resized_image_output_name, resized_images) + block_state.resized_image = resized_images + block_state.resized_cond_image = resized_cond_images self.set_block_state(state, block_state) return components, state @@ -531,14 +543,38 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks): # ==================== # 2. GET IMAGE PROMPT # ==================== + + +# auto_docstring class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): """ - Auto-caption step that generates a text prompt from the input image if none is provided. 
Uses the VL model to - generate a description of the image. + Auto-caption step that generates a text prompt from the input image if none is provided. + Uses the VL model (text_encoder) to generate a description of the image. If prompt is already provided, this step + passes through unchanged. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + resized_image (`Image`): + The image to generate caption from, should be resized use the resize step + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + + Outputs: + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption """ model_name = "qwenimage-layered" + def __init__(self): + self.image_caption_prompt_en = QWENIMAGE_LAYERED_CAPTION_PROMPT_EN + self.image_caption_prompt_cn = QWENIMAGE_LAYERED_CAPTION_PROMPT_CN + super().__init__() + @property def description(self) -> str: return ( @@ -554,17 +590,12 @@ class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): ComponentSpec("processor", Qwen2VLProcessor), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec(name="image_caption_prompt_en", default=QWENIMAGE_LAYERED_CAPTION_PROMPT_EN), - ConfigSpec(name="image_caption_prompt_cn", default=QWENIMAGE_LAYERED_CAPTION_PROMPT_CN), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", type_hint=str, description="The prompt to encode"), + InputParam.template( + "prompt", required=False + ), # it is not required for qwenimage-layered, unlike other pipelines InputParam( name="resized_image", required=True, @@ -579,6 +610,16 @@ class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): ), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="prompt", + type_hint=str, + description="The prompt or prompts to guide image generation. If not provided, updated using image caption", + ), + ] + @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -588,9 +629,9 @@ class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): # If prompt is empty or None, generate caption from image if block_state.prompt is None or block_state.prompt == "" or block_state.prompt == " ": if block_state.use_en_prompt: - caption_prompt = components.config.image_caption_prompt_en + caption_prompt = self.image_caption_prompt_en else: - caption_prompt = components.config.image_caption_prompt_cn + caption_prompt = self.image_caption_prompt_cn model_inputs = components.processor( text=caption_prompt, @@ -616,9 +657,44 @@ class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): # ==================== # 3. TEXT ENCODER # ==================== + + +# auto_docstring class QwenImageTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that generates text embeddings to guide the image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. 
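`prompt_template_encode` and `prompt_template_encode_start_idx`, now plain instance attributes, hint at the usual chat-template trick: wrap the user prompt in a fixed scaffold, encode the whole thing, then drop the scaffold's leading tokens from the hidden states. The sketch below only illustrates that trimming step; the template string, the token count, and the fake encoder are placeholders, not the real `QWENIMAGE_PROMPT_TEMPLATE` or tokenizer.

import torch

PROMPT_TEMPLATE = "<system>describe the edit</system><user>{}</user>"  # placeholder template
TEMPLATE_START_IDX = 6                                                 # assumed scaffold length in tokens

def fake_encode(text: str) -> torch.Tensor:
    # stand-in for tokenizer + text encoder: one "hidden state" per whitespace token
    return torch.arange(len(text.split()), dtype=torch.float32).unsqueeze(0).unsqueeze(-1)

def encode_prompt(prompt: str) -> torch.Tensor:
    full = PROMPT_TEMPLATE.format(prompt)
    hidden = fake_encode(full)               # (1, seq, dim)
    return hidden[:, TEMPLATE_START_IDX:]    # drop the template prefix, keep the prompt part

print(encode_prompt("a cat riding a bicycle through the rain at night").shape)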
+ max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + model_name = "qwenimage" + def __init__(self): + self.prompt_template_encode = QWENIMAGE_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_PROMPT_TEMPLATE_START_IDX + self.tokenizer_max_length = 1024 + super().__init__() + @property def description(self) -> str: return "Text Encoder step that generates text embeddings to guide the image generation." @@ -636,51 +712,21 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks): ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_PROMPT_TEMPLATE), - ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_PROMPT_TEMPLATE_START_IDX), - ConfigSpec(name="tokenizer_max_length", default=1024), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), - InputParam( - name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024 - ), + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length", default=1024), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - name="prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The prompt embeddings", - ), - OutputParam( - name="prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The encoder attention mask", - ), - OutputParam( - name="negative_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings", - ), - OutputParam( - name="negative_prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings mask", - ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), ] @staticmethod @@ -709,9 +755,9 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks): components.text_encoder, components.tokenizer, prompt=block_state.prompt, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, - tokenizer_max_length=components.config.tokenizer_max_length, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, device=device, ) @@ -726,9 +772,9 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks): components.text_encoder, components.tokenizer, prompt=negative_prompt, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, - tokenizer_max_length=components.config.tokenizer_max_length, + 
prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, device=device, ) block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[ @@ -742,9 +788,42 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageEditTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image + generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_image (`Image`): + The image prompt to encode, should be resized using resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + model_name = "qwenimage" + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX + super().__init__() + @property def description(self) -> str: return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation." @@ -762,18 +841,11 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks): ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE), - ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam.template("prompt"), + InputParam.template("negative_prompt"), InputParam( name="resized_image", required=True, @@ -785,30 +857,10 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks): @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - name="prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The prompt embeddings", - ), - OutputParam( - name="prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The encoder attention mask", - ), - OutputParam( - name="negative_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings", - ), - OutputParam( - name="negative_prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings mask", - ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), ] @staticmethod @@ -836,8 +888,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks): components.processor, prompt=block_state.prompt, 
image=block_state.resized_image, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -850,8 +902,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks): components.processor, prompt=negative_prompt, image=block_state.resized_image, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -859,11 +911,44 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): - """Text encoder for QwenImage Edit Plus (VL encoding with multiple images).""" + """ + Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together to generate text + embeddings for guiding image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_cond_image (`Tensor`): + The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using + resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. 
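The paired conditional/negative embedding outputs exist for the `guider` (`ClassifierFreeGuidance`): the denoiser runs on both branches and the two predictions are combined. As a reminder of the standard formulation (general knowledge, not QwenImage-specific code):

import torch

def classifier_free_guidance(
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    guidance_scale: float = 4.0,
) -> torch.Tensor:
    # move the prediction away from the unconditional branch, toward the conditional one
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

cond = torch.ones(1, 16, 1, 8, 8)
uncond = torch.zeros(1, 16, 1, 8, 8)
print(classifier_free_guidance(cond, uncond).unique())  # tensor([4.])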
+ """ model_name = "qwenimage-edit-plus" + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE + self.img_template_encode = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX + super().__init__() + @property def description(self) -> str: return ( @@ -884,19 +969,11 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE), - ConfigSpec(name="img_template_encode", default=QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE), - ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam.template("prompt"), + InputParam.template("negative_prompt"), InputParam( name="resized_cond_image", required=True, @@ -908,30 +985,10 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - name="prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The prompt embeddings", - ), - OutputParam( - name="prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The encoder attention mask", - ), - OutputParam( - name="negative_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings", - ), - OutputParam( - name="negative_prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings mask", - ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), ] @staticmethod @@ -959,9 +1016,9 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): components.processor, prompt=block_state.prompt, image=block_state.resized_cond_image, - prompt_template_encode=components.config.prompt_template_encode, - img_template_encode=components.config.img_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -975,9 +1032,9 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): components.processor, prompt=negative_prompt, image=block_state.resized_cond_image, - prompt_template_encode=components.config.prompt_template_encode, - img_template_encode=components.config.img_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) ) @@ -989,7 +1046,38 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): # ==================== # 4. 
IMAGE PREPROCESS # ==================== + + +# auto_docstring class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be + resized to the given height and width. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + model_name = "qwenimage" @property @@ -1010,18 +1098,26 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("mask_image", required=True), - InputParam("image", required=True), - InputParam("height"), - InputParam("width"), - InputParam("padding_mask_crop"), + InputParam.template("mask_image"), + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("padding_mask_crop"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="processed_image"), - OutputParam(name="processed_mask_image"), + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), OutputParam( name="mask_overlay_kwargs", type_hint=Dict, @@ -1061,7 +1157,32 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be + resized first. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + resized_image (`Image`): + The resized image. should be generated using a resize step + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + model_name = "qwenimage-edit" @property @@ -1082,16 +1203,25 @@ class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("mask_image", required=True), - InputParam("resized_image", required=True), - InputParam("padding_mask_crop"), + InputParam.template("mask_image"), + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The resized image. 
should be generated using a resize step", + ), + InputParam.template("padding_mask_crop"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="processed_image"), - OutputParam(name="processed_mask_image"), + OutputParam(name="processed_image", type_hint=torch.Tensor, description="The processed image"), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), OutputParam( name="mask_overlay_kwargs", type_hint=Dict, @@ -1119,7 +1249,27 @@ class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. will resize the image to the given height and width. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + processed_image (`Tensor`): + The processed image + """ + model_name = "qwenimage" @property @@ -1140,14 +1290,20 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("image", required=True), - InputParam("height"), - InputParam("width"), + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), ] @property def intermediate_outputs(self) -> List[OutputParam]: - return [OutputParam(name="processed_image")] + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] @staticmethod def check_inputs(height, width, vae_scale_factor): @@ -1177,7 +1333,23 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images needs to be resized first. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`List`): + The resized image. should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + model_name = "qwenimage-edit" @property @@ -1198,12 +1370,23 @@ class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam("resized_image", required=True), + InputParam( + name="resized_image", + required=True, + type_hint=List[PIL.Image.Image], + description="The resized image. should be generated using a resize step", + ), ] @property def intermediate_outputs(self) -> List[OutputParam]: - return [OutputParam(name="processed_image")] + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): @@ -1221,12 +1404,29 @@ class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of + processed images. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`List`): + The resized image. 
should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + model_name = "qwenimage-edit-plus" @property def description(self) -> str: - return "Image Preprocess step. Images can be resized first using QwenImageEditResizeStep." + return "Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of processed images." @property def expected_components(self) -> List[ComponentSpec]: @@ -1241,11 +1441,24 @@ class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: - return [InputParam("resized_image")] + return [ + InputParam( + name="resized_image", + required=True, + type_hint=List[PIL.Image.Image], + description="The resized image. should be generated using a resize step", + ) + ] @property def intermediate_outputs(self) -> List[OutputParam]: - return [OutputParam(name="processed_image")] + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): @@ -1263,7 +1476,7 @@ class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): processed_images.append( components.image_processor.preprocess(image=img, height=img_height, width=img_width) ) - block_state.processed_image = processed_images + if is_image_list: block_state.processed_image = processed_images else: @@ -1276,15 +1489,34 @@ class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): # ==================== # 5. VAE ENCODER # ==================== + + +# auto_docstring class QwenImageVaeEncoderStep(ModularPipelineBlocks): - """VAE encoder that handles both single images and lists of images with varied resolutions.""" + """ + VAE Encoder step that converts processed_image into latent representations image_latents. + Handles both single images and lists of images with varied resolutions. + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + processed_image (`Tensor`): + The image tensor to encode + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. + """ model_name = "qwenimage" def __init__( self, - input_name: str = "processed_image", - output_name: str = "image_latents", + input: Optional[InputParam] = None, + output: Optional[OutputParam] = None, ): """Initialize a VAE encoder step for converting images to latent representations. @@ -1292,11 +1524,26 @@ class QwenImageVaeEncoderStep(ModularPipelineBlocks): a single tensor, outputs a single latent tensor. Args: - input_name (str, optional): Name of the input image tensor or list. Defaults to "processed_image". - output_name (str, optional): Name of the output latent tensor or list. Defaults to "image_latents". + input (InputParam, optional): Input parameter for the processed image. Defaults to "processed_image". + output (OutputParam, optional): Output parameter for the image latents. Defaults to "image_latents". 
""" - self._image_input_name = input_name - self._image_latents_output_name = output_name + if input is None: + input = InputParam( + name="processed_image", required=True, type_hint=torch.Tensor, description="The image tensor to encode" + ) + + if output is None: + output = OutputParam.template("image_latents") + + if not isinstance(input, InputParam): + raise ValueError(f"input must be InputParam but is {type(input)}") + if not isinstance(output, OutputParam): + raise ValueError(f"output must be OutputParam but is {type(output)}") + + self._input = input + self._output = output + self._image_input_name = input.name + self._image_latents_output_name = output.name super().__init__() @property @@ -1312,17 +1559,14 @@ class QwenImageVaeEncoderStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: - return [InputParam(self._image_input_name, required=True), InputParam("generator")] + return [ + self._input, # default is "processed_image" + InputParam.template("generator"), + ] @property def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - self._image_latents_output_name, - type_hint=torch.Tensor, - description="The latents representing the reference image(s). Single tensor or list depending on input.", - ) - ] + return [self._output] # default is "image_latents" @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -1359,7 +1603,30 @@ class QwenImageVaeEncoderStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): + """ + VAE Encoder step that converts `control_image` into latent representations control_image_latents. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + model_name = "qwenimage" @property @@ -1383,10 +1650,10 @@ class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam("control_image", required=True), - InputParam("height"), - InputParam("width"), - InputParam("generator"), + InputParam.template("control_image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("generator"), ] return inputs @@ -1473,23 +1740,38 @@ class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): # ==================== # 6. PERMUTE LATENTS # ==================== + + +# auto_docstring class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks): - """Permute image latents from VAE format to Layered format.""" + """ + Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing. + + Inputs: + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. 
(permuted from [B, C, 1, H, W] to [B, 1, C, H, W]) + """ model_name = "qwenimage-layered" - def __init__(self, input_name: str = "image_latents"): - self._input_name = input_name - super().__init__() - @property def description(self) -> str: - return f"Permute {self._input_name} from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing." + return "Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing." @property def inputs(self) -> List[InputParam]: return [ - InputParam(self._input_name, required=True), + InputParam.template("image_latents"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("image_latents", note="permuted from [B, C, 1, H, W] to [B, 1, C, H, W]"), ] @torch.no_grad() @@ -1497,8 +1779,8 @@ class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks): block_state = self.get_block_state(state) # Permute: (B, C, 1, H, W) -> (B, 1, C, H, W) - latents = getattr(block_state, self._input_name) - setattr(block_state, self._input_name, latents.permute(0, 2, 1, 3, 4)) + latents = block_state.image_latents + block_state.image_latents = latents.permute(0, 2, 1, 3, 4) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py index 4a1cf3700c..818bbca5ed 100644 --- a/src/diffusers/modular_pipelines/qwenimage/inputs.py +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import List, Optional, Tuple import torch @@ -109,7 +109,44 @@ def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: in return height, width +# auto_docstring class QwenImageTextInputsStep(ModularPipelineBlocks): + """ + Text input processing step that standardizes text embeddings for the pipeline. + This step: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt) + + This block should be placed after all encoder steps to process the text embeddings before they are used in + subsequent pipeline steps. + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. 
(batch-expanded) + """ + model_name = "qwenimage" @property @@ -129,26 +166,22 @@ class QwenImageTextInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"), - InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"), - InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"), - InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"), + InputParam.template("num_images_per_prompt"), + InputParam.template("prompt_embeds"), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds"), + InputParam.template("negative_prompt_embeds_mask"), ] @property - def intermediate_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - "batch_size", - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", - ), - OutputParam( - "dtype", - type_hint=torch.dtype, - description="Data type of model tensor inputs (determined by `prompt_embeds`)", - ), + OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"), + OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"), + OutputParam.template("prompt_embeds", note="batch-expanded"), + OutputParam.template("prompt_embeds_mask", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"), ] @staticmethod @@ -221,20 +254,76 @@ class QwenImageTextInputsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageAdditionalInputsStep(ModularPipelineBlocks): - """Input step for QwenImage: update height/width, expand batch, patchify.""" + """ + Input processing step that: + 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
(patchified and + batch-expanded) + """ model_name = "qwenimage" def __init__( self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, ): + # by default, process `image_latents` + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + if not isinstance(image_latent_inputs, list): - image_latent_inputs = [image_latent_inputs] + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + if not isinstance(additional_batch_inputs, list): - additional_batch_inputs = [additional_batch_inputs] + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) self._image_latent_inputs = image_latent_inputs self._additional_batch_inputs = additional_batch_inputs @@ -252,9 +341,9 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks): if self._image_latent_inputs or self._additional_batch_inputs: inputs_info = "\n\nConfigured inputs:" if self._image_latent_inputs: - inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" if self._additional_batch_inputs: - inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
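For illustration, a minimal sketch (import paths assumed from the file layout in this diff) of how the refactored QwenImageAdditionalInputsStep is configured: inputs are now passed as InputParam objects instead of bare strings, so type hints and descriptions flow into the auto-generated docstrings, and anything else raises a ValueError.

import torch

# Import paths are assumed from the diff's file layout.
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam
from diffusers.modular_pipelines.qwenimage.inputs import QwenImageAdditionalInputsStep

# Default configuration: patchifies and batch-expands `image_latents`.
step = QwenImageAdditionalInputsStep()

# Custom configuration: also batch-expand an extra tensor, mirroring how
# QwenImageInpaintInputStep wires in `processed_mask_image` later in this diff.
inpaint_inputs_step = QwenImageAdditionalInputsStep(
    additional_batch_inputs=[
        InputParam(
            name="processed_mask_image",
            type_hint=torch.Tensor,
            description="The processed mask image",
        )
    ],
)

# Bare strings are no longer silently wrapped:
# QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])  # raises ValueError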
@@ -269,23 +358,19 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="batch_size", required=True), - InputParam(name="height"), - InputParam(name="width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), ] - - for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) - - for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs return inputs @property def intermediate_outputs(self) -> List[OutputParam]: - return [ + outputs = [ OutputParam( name="image_height", type_hint=int, @@ -298,11 +383,43 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks): ), ] + # `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Process image latent inputs - for image_latent_input_name in self._image_latent_inputs: + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: continue @@ -331,7 +448,8 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks): setattr(block_state, image_latent_input_name, image_latent_tensor) # Process additional batch inputs (only batch expansion) - for input_name in self._additional_batch_inputs: + for input_param in self._additional_batch_inputs: + input_name = input_param.name input_tensor = getattr(block_state, input_name) if input_tensor is None: continue @@ -349,20 +467,76 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): - """Input step for QwenImage Edit Plus: handles list of latents with different sizes.""" + """ + Input processing step for Edit Plus that: + 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch + 2. For additional batch inputs: Expands batch dimensions to match final batch size + Height/width defaults to last image in the list. 
+ + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`List`): + The image heights calculated from the image latents dimension + image_width (`List`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ model_name = "qwenimage-edit-plus" def __init__( self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + if not isinstance(image_latent_inputs, list): - image_latent_inputs = [image_latent_inputs] + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + if not isinstance(additional_batch_inputs, list): - additional_batch_inputs = [additional_batch_inputs] + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) self._image_latent_inputs = image_latent_inputs self._additional_batch_inputs = additional_batch_inputs @@ -381,9 +555,9 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): if self._image_latent_inputs or self._additional_batch_inputs: inputs_info = "\n\nConfigured inputs:" if self._image_latent_inputs: - inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" if self._additional_batch_inputs: - inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
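As a quick sketch of what the property above produces (assuming OutputParam exposes name and description attributes, as its constructor arguments suggest), the configured image latent inputs reappear as outputs with a "(patchified, concatenated, and batch-expanded)" note appended to their descriptions:

from diffusers.modular_pipelines.qwenimage.inputs import QwenImageEditPlusAdditionalInputsStep

step = QwenImageEditPlusAdditionalInputsStep()  # defaults to processing `image_latents`
for out in step.intermediate_outputs:
    print(f"{out.name}: {out.description}")
# Expected names: image_height, image_width, height, width, image_latents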
@@ -398,23 +572,20 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="batch_size", required=True), - InputParam(name="height"), - InputParam(name="width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), ] - for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) - - for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs return inputs @property def intermediate_outputs(self) -> List[OutputParam]: - return [ + outputs = [ OutputParam( name="image_height", type_hint=List[int], @@ -427,11 +598,43 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): ), ] + # `height`/`width` are updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified, concatenated, and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified, concatenated, and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Process image latent inputs - for image_latent_input_name in self._image_latent_inputs: + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: continue @@ -476,7 +679,8 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): setattr(block_state, image_latent_input_name, packed_image_latent_tensors) # Process additional batch inputs (only batch expansion) - for input_name in self._additional_batch_inputs: + for input_param in self._additional_batch_inputs: + input_name = input_param.name input_tensor = getattr(block_state, input_name) if input_tensor is None: continue @@ -494,22 +698,75 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): return components, state -# YiYi TODO: support define config default component from the ModularPipeline level. -# it is same as QwenImageAdditionalInputsStep, but with layered pachifier. +# same as QwenImageAdditionalInputsStep, but with layered pachifier. + + +# auto_docstring class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): - """Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier.""" + """ + Input processing step for Layered that: + 1. 
For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch + size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ model_name = "qwenimage-layered" def __init__( self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + if not isinstance(image_latent_inputs, list): - image_latent_inputs = [image_latent_inputs] + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + if not isinstance(additional_batch_inputs, list): - additional_batch_inputs = [additional_batch_inputs] + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) self._image_latent_inputs = image_latent_inputs self._additional_batch_inputs = additional_batch_inputs @@ -527,9 +784,9 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): if self._image_latent_inputs or self._additional_batch_inputs: inputs_info = "\n\nConfigured inputs:" if self._image_latent_inputs: - inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" if self._additional_batch_inputs: - inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
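The layered blocks above operate on latents laid out as (B, 1, C, H, W) rather than the VAE's (B, C, 1, H, W); the conversion performed by QwenImageLayeredPermuteLatentsStep reduces to a single permute, sketched here with illustrative shapes (the 16-channel latent is an assumption for the example):

import torch

# VAE-format latents: (batch, channels, frames=1, height, width)
latents = torch.randn(2, 16, 1, 64, 64)

# Layered format: (batch, 1, channels, height, width), matching the permute in
# QwenImageLayeredPermuteLatentsStep.__call__
layered = latents.permute(0, 2, 1, 3, 4)
print(layered.shape)  # torch.Size([2, 1, 16, 64, 64])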
@@ -544,21 +801,18 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="batch_size", required=True), + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), ] + # default is `image_latents` - for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) - - for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + inputs += self._image_latent_inputs + self._additional_batch_inputs return inputs @property def intermediate_outputs(self) -> List[OutputParam]: - return [ + outputs = [ OutputParam( name="image_height", type_hint=int, @@ -569,15 +823,44 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): type_hint=int, description="The image width calculated from the image latents dimension", ), - OutputParam(name="height", type_hint=int, description="The height of the image output"), - OutputParam(name="width", type_hint=int, description="The width of the image output"), ] + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # Add outputs for image latent inputs (patchified with layered pachifier and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified with layered pachifier and batch-expanded)", + ) + ) + + # Add outputs for additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Process image latent inputs - for image_latent_input_name in self._image_latent_inputs: + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: continue @@ -608,7 +891,8 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): setattr(block_state, image_latent_input_name, image_latent_tensor) # Process additional batch inputs (only batch expansion) - for input_name in self._additional_batch_inputs: + for input_param in self._additional_batch_inputs: + input_name = input_param.name input_tensor = getattr(block_state, input_name) if input_tensor is None: continue @@ -626,7 +910,34 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): return components, state +# auto_docstring class QwenImageControlNetInputsStep(ModularPipelineBlocks): + """ + prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps. + + Inputs: + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. 
+ batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + control_image_latents (`Tensor`): + The control image latents (patchified and batch-expanded). + height (`int`): + if not provided, updated to control image height + width (`int`): + if not provided, updated to control image width + """ + model_name = "qwenimage" @property @@ -636,11 +947,28 @@ class QwenImageControlNetInputsStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="control_image_latents", required=True), - InputParam(name="batch_size", required=True), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="height"), - InputParam(name="width"), + InputParam( + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam.template("batch_size"), + InputParam.template("num_images_per_prompt"), + InputParam.template("height"), + InputParam.template("width"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="control_image_latents", + type_hint=torch.Tensor, + description="The control image latents (patchified and batch-expanded).", + ), + OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"), + OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"), ] @torch.no_grad() diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py index ebe0bbbd75..5837799d34 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -import PIL.Image import torch from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict, OutputParam +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam from .before_denoise import ( QwenImageControlNetBeforeDenoiserStep, QwenImageCreateMaskLatentsStep, @@ -59,11 +56,91 @@ logger = logging.get_logger(__name__) # ==================== -# 1. VAE ENCODER +# 1. TEXT ENCODER # ==================== +# auto_docstring +class QwenImageAutoTextEncoderStep(AutoPipelineBlocks): + """ + Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. 
+ max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextEncoderStep()] + block_names = ["text_encoder"] + block_trigger_inputs = ["prompt"] + + @property + def description(self) -> str: + return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block." + " - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided." + " - if `prompt` is not provided, step will be skipped." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# auto_docstring class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for inpainting tasks. It: + - Resizes the image to the target size, based on `height` and `width`. + - Processes and updates `image` and `mask_image`. + - Creates `image_latents`. + + Components: + image_mask_processor (`InpaintProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage" block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()] block_names = ["preprocess", "encode"] @@ -78,7 +155,31 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): ) +# auto_docstring class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocesses and encodes the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage" block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()] @@ -89,7 +190,6 @@ class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."
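The QwenImageVaeEncoderStep() used inside these sequential blocks reads `processed_image` and writes `image_latents` by default; with the refactored constructor shown earlier in this diff it can be rewired by passing InputParam/OutputParam objects. A hedged sketch (the encoder step's import path and the field names below are assumptions for illustration):

import torch

# Import paths are assumed for illustration.
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
from diffusers.modular_pipelines.qwenimage.encoders import QwenImageVaeEncoderStep

# Encode a differently named state field into differently named latents
# (both names are hypothetical).
control_encoder = QwenImageVaeEncoderStep(
    input=InputParam(
        name="processed_control_image",
        required=True,
        type_hint=torch.Tensor,
        description="The control image tensor to encode",
    ),
    output=OutputParam(
        name="control_image_latents",
        type_hint=torch.Tensor,
        description="The latents representing the control image",
    ),
)

# Anything that is not an InputParam/OutputParam is rejected:
# QwenImageVaeEncoderStep(input="processed_image")  # raises ValueError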
-# Auto VAE encoder class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] block_names = ["inpaint", "img2img"] @@ -107,7 +207,33 @@ class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): # optional controlnet vae encoder +# auto_docstring class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block. + - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided. + - if `control_image` is not provided, step will be skipped. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + block_classes = [QwenImageControlNetVaeEncoderStep] block_names = ["controlnet"] block_trigger_inputs = ["control_image"] @@ -123,14 +249,65 @@ class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): # ==================== -# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) # ==================== # assemble input steps +# auto_docstring class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the img2img denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. 
(batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + model_name = "qwenimage" - block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])] + block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()] block_names = ["text_inputs", "additional_inputs"] @property @@ -140,12 +317,69 @@ class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): " - update height/width based `image_latents`, patchify `image_latents`." +# auto_docstring class QwenImageInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the inpainting denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
(patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), QwenImageAdditionalInputsStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] ), ] block_names = ["text_inputs", "additional_inputs"] @@ -158,7 +392,42 @@ class QwenImageInpaintInputStep(SequentialPipelineBlocks): # assemble prepare latents steps +# auto_docstring class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the pachified latents `mask` based on the processed mask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noise, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noise used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + model_name = "qwenimage" block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] block_names = ["add_noise_to_latents", "create_mask_latents"] @@ -176,7 +445,49 @@ class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): # Qwen Image (text2image) +# auto_docstring class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation.
+ height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), @@ -199,9 +510,63 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (inpainting) +# auto_docstring class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageInpaintInputStep(), @@ -226,9 +591,61 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) 
for the denoise step for inpaint task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (image2image) +# auto_docstring class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageImg2ImgInputStep(), @@ -253,9 +670,66 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (text2image) with controlnet +# auto_docstring class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. 
Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), @@ -282,9 +756,72 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (inpainting) with controlnet +# auto_docstring class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. 
+ latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageInpaintInputStep(), @@ -313,9 +850,70 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (image2image) with controlnet +# auto_docstring class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. 
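The ControlNet-enabled denoise steps above take `control_guidance_start`, `control_guidance_end`, and `controlnet_conditioning_scale`; together they define the window of denoising steps during which the ControlNet residuals are applied, and how strongly. A rough sketch of that rule, assuming the usual fraction-of-steps convention (real pipelines typically precompute a per-step keep list rather than calling a helper like this):

```python
def controlnet_scale_for_step(
    step_index: int,
    num_inference_steps: int,
    controlnet_conditioning_scale: float = 1.0,
    control_guidance_start: float = 0.0,
    control_guidance_end: float = 1.0,
) -> float:
    """Return the ControlNet conditioning scale to use at a given step (0.0 disables it)."""
    # Progress through the schedule as a fraction in [0, 1].
    progress = step_index / max(num_inference_steps - 1, 1)
    inside_window = control_guidance_start <= progress <= control_guidance_end
    return controlnet_conditioning_scale if inside_window else 0.0
```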
+ control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageImg2ImgInputStep(), @@ -344,6 +942,12 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Auto denoise step for QwenImage class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks): @@ -402,19 +1006,36 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.template("latents"), ] # ==================== -# 3. DECODE +# 4. DECODE # ==================== # standard decode step works for most tasks except for inpaint +# auto_docstring class QwenImageDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -425,7 +1046,30 @@ class QwenImageDecodeStep(SequentialPipelineBlocks): # Inpaint decode step +# auto_docstring class QwenImageInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask + overally to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage" block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -452,11 +1096,11 @@ class QwenImageAutoDecodeStep(AutoPipelineBlocks): # ==================== -# 4. AUTO BLOCKS & PRESETS +# 5. 
AUTO BLOCKS & PRESETS # ==================== AUTO_BLOCKS = InsertableDict( [ - ("text_encoder", QwenImageTextEncoderStep()), + ("text_encoder", QwenImageAutoTextEncoderStep()), ("vae_encoder", QwenImageAutoVaeEncoderStep()), ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()), ("denoise", QwenImageAutoCoreDenoiseStep()), @@ -465,7 +1109,89 @@ AUTO_BLOCKS = InsertableDict( ) +# auto_docstring class QwenImageAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage. + - for image-to-image generation, you need to provide `image` + - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`. + - to run the controlnet workflow, you need to provide `control_image` + - for text-to-image generation, all you need to provide is `prompt` + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) controlnet (`QwenImageControlNetModel`) + control_image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + mask_image (`Image`, *optional*): + Mask image for inpainting. + image (`Union[Image, List]`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
+ processed_mask_image (`Tensor`, *optional*): + The processed mask image + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_image_latents (`Tensor`, *optional*): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" block_classes = AUTO_BLOCKS.values() @@ -476,7 +1202,7 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks): return ( "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n" + "- for image-to-image generation, you need to provide `image`\n" - + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n" + "- to run the controlnet workflow, you need to provide `control_image`\n" + "- for text-to-image generation, all you need to provide is `prompt`" ) @@ -484,5 +1210,5 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]), + OutputParam.template("images"), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py index 2683e64080..e1e5c43354 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Optional -import PIL.Image import torch from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict, OutputParam +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam from .before_denoise import ( QwenImageCreateMaskLatentsStep, QwenImageEditRoPEInputsStep, @@ -59,8 +58,35 @@ logger = logging.get_logger(__name__) # ==================== +# auto_docstring class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): - """VL encoder that takes both image and text prompts.""" + """ + QwenImage-Edit VL encoder step that encode the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. 
+ negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`List`): + The resized images + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ model_name = "qwenimage-edit" block_classes = [ @@ -80,7 +106,30 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): # Edit VAE encoder +# auto_docstring class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditResizeStep(), @@ -95,12 +144,46 @@ class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): # Edit Inpaint VAE encoder +# auto_docstring class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It: + - resize the image for target area (1024 * 1024) while maintaining the aspect ratio. + - process the resized image and mask image. + - create image latents. + + Components: + image_resize_processor (`VaeImageProcessor`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + mask_image (`Image`): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditResizeStep(), QwenImageEditInpaintProcessImagesInputStep(), - QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"), + QwenImageVaeEncoderStep(), ] block_names = ["resize", "preprocess", "encode"] @@ -137,11 +220,64 @@ class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks): # assemble input steps +# auto_docstring class QwenImageEditInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageTextInputsStep(), - QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]), + QwenImageAdditionalInputsStep(), ] block_names = ["text_inputs", "additional_inputs"] @@ -154,12 +290,71 @@ class QwenImageEditInputStep(SequentialPipelineBlocks): ) +# auto_docstring class QwenImageEditInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit inpaint denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
+ processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageTextInputsStep(), QwenImageAdditionalInputsStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] ), ] block_names = ["text_inputs", "additional_inputs"] @@ -174,7 +369,42 @@ class QwenImageEditInpaintInputStep(SequentialPipelineBlocks): # assemble prepare latents steps +# auto_docstring class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the patchified latents `mask` based on the processed mask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + model_name = "qwenimage-edit" block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] block_names = ["add_noise_to_latents", "create_mask_latents"] @@ -189,7 +419,50 @@ class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks): # Qwen Image Edit (image2image) core denoise step +# auto_docstring class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit (img2img) task. 
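The `QwenImageEditInpaintPrepareLatentsStep` described above adds noise to the image latents before the denoising loop starts. For flow-matching schedulers such as `FlowMatchEulerDiscreteScheduler`, the usual img2img recipe is to truncate the sigma schedule according to `strength` and then blend the clean latents with noise at the first retained sigma. A minimal sketch of that recipe (illustrative only; the block's actual bookkeeping may differ):

```python
import torch

def noise_image_latents(image_latents, sigmas, num_inference_steps, strength, generator=None):
    """Truncate the schedule by `strength` and noise the clean image latents accordingly."""
    # strength=1.0 keeps the full schedule (pure-noise start); smaller values start later.
    init_steps = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = num_inference_steps - init_steps
    sigmas = sigmas[t_start:]

    # Flow-matching forward process: x_sigma = (1 - sigma) * x_0 + sigma * noise.
    noise = torch.randn(image_latents.shape, generator=generator, dtype=image_latents.dtype)
    latents = (1.0 - sigmas[0]) * image_latents + sigmas[0] * noise
    return latents, sigmas, noise
```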
+ + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditInputStep(), @@ -212,9 +485,62 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Core denoising workflow for QwenImage-Edit edit (img2img) task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image Edit (inpainting) core denoise step +# auto_docstring class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit inpaint task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. 
+ generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditInpaintInputStep(), @@ -239,6 +565,12 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Core denoising workflow for QwenImage-Edit edit inpaint task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Auto core denoise step for QwenImage Edit class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks): @@ -267,6 +599,12 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks): "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit." ) + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # ==================== # 4. DECODE @@ -274,7 +612,26 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks): # Decode step (standard) +# auto_docstring class QwenImageEditDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage-edit" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -285,7 +642,30 @@ class QwenImageEditDecodeStep(SequentialPipelineBlocks): # Inpaint decode step +# auto_docstring class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask + overlay to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) 
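Several blocks in this patch replace hand-written `OutputParam(...)` declarations with `OutputParam.template("latents")` or `OutputParam.template("images")`. The factory itself is not shown in the diff; a plausible shape is a small registry of canonical output definitions, sketched below purely for illustration (the field names mirror `OutputParam`, everything else is an assumption):

```python
import torch
from dataclasses import dataclass, replace
from typing import Any, List, Optional

@dataclass
class OutputParam:
    name: str
    type_hint: Optional[Any] = None
    description: str = ""

    # Canonical, reusable output definitions keyed by name (illustrative values).
    _TEMPLATES = {
        "latents": dict(type_hint=torch.Tensor, description="Denoised latents."),
        "images": dict(type_hint=List, description="Generated images."),
    }

    @classmethod
    def template(cls, name: str, **overrides):
        """Build an OutputParam from a canonical definition, optionally overriding fields."""
        base = cls(name=name, **cls._TEMPLATES[name])
        return replace(base, **overrides) if overrides else base
```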
+ """ + model_name = "qwenimage-edit" block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -313,9 +693,7 @@ class QwenImageEditAutoDecodeStep(AutoPipelineBlocks): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.template("latents"), ] @@ -333,7 +711,66 @@ EDIT_AUTO_BLOCKS = InsertableDict( ) +# auto_docstring class QwenImageEditAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit. + - for edit (img2img) generation, you need to provide `image` + - for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide + `padding_mask_crop` + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + mask_image (`Image`, *optional*): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. 
+ """ + model_name = "qwenimage-edit" block_classes = EDIT_AUTO_BLOCKS.values() block_names = EDIT_AUTO_BLOCKS.keys() @@ -349,5 +786,5 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.template("images"), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py index 99c5b109bf..37656cef5d 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py @@ -12,11 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -import PIL.Image -import torch - from ...utils import logging from ..modular_pipeline import SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam @@ -53,12 +48,41 @@ logger = logging.get_logger(__name__) # ==================== +# auto_docstring class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): - """VL encoder that takes both image and text prompts. Uses 384x384 target area.""" + """ + QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ model_name = "qwenimage-edit-plus" block_classes = [ - QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"), + QwenImageEditPlusResizeStep(), QwenImageEditPlusTextEncoderStep(), ] block_names = ["resize", "encode"] @@ -73,12 +97,36 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): # ==================== +# auto_docstring class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): - """VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area.""" + """ + VAE encoder step that encodes image inputs into latent representations. + Each image is resized independently based on its own aspect ratio to 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. 
+ + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ model_name = "qwenimage-edit-plus" block_classes = [ - QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"), + QwenImageEditPlusResizeStep(), QwenImageEditPlusProcessImagesInputStep(), QwenImageVaeEncoderStep(), ] @@ -98,11 +146,66 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): # assemble input steps +# auto_docstring class QwenImageEditPlusInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the Edit Plus denoising step. It: + - Standardizes text embeddings batch size. + - Processes list of image latents: patchifies, concatenates along dim=1, expands batch. + - Outputs lists of image_height/image_width for RoPE calculation. + - Defaults height/width from last image in the list. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`List`): + The image heights calculated from the image latents dimension + image_width (`List`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ + model_name = "qwenimage-edit-plus" block_classes = [ QwenImageTextInputsStep(), - QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]), + QwenImageEditPlusAdditionalInputsStep(), ] block_names = ["text_inputs", "additional_inputs"] @@ -118,7 +221,50 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks): # Qwen Image Edit Plus (image2image) core denoise step +# auto_docstring class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit Plus edit (img2img) task. 
+ + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit-plus" block_classes = [ QwenImageEditPlusInputStep(), @@ -144,9 +290,7 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.template("latents"), ] @@ -155,7 +299,26 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): # ==================== +# auto_docstring class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocesses the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage-edit-plus" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -179,7 +342,53 @@ EDIT_PLUS_AUTO_BLOCKS = InsertableDict( ) +# auto_docstring class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus. + - `image` is required input (can be single image or list of images). + - Each image is resized independently based on its own aspect ratio. + - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area. 
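The Edit Plus blocks resize each reference image to a fixed target area (roughly 384x384 pixels for the VL encoder, 1024x1024 for the VAE path) while keeping that image's own aspect ratio. A sketch of the calculation, assuming dimensions are snapped to a multiple of 32; the exact rounding used by the resize step is not visible in this diff:

```python
import math
from PIL import Image

def resize_to_target_area(image: Image.Image, target_area: int = 1024 * 1024, multiple: int = 32) -> Image.Image:
    """Rescale so width * height is approximately target_area, preserving aspect ratio."""
    width, height = image.size
    scale = math.sqrt(target_area / (width * height))
    new_width = max(multiple, round(width * scale / multiple) * multiple)
    new_height = max(multiple, round(height * scale / multiple) * multiple)
    return image.resize((new_width, new_height), Image.LANCZOS)
```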
+ + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) + transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-edit-plus" block_classes = EDIT_PLUS_AUTO_BLOCKS.values() block_names = EDIT_PLUS_AUTO_BLOCKS.keys() @@ -196,5 +405,5 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.template("images"), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py index 63ee36df51..fdfeab0488 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py @@ -12,12 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from typing import List - -import PIL.Image -import torch - from ...utils import logging from ..modular_pipeline import SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam @@ -55,8 +49,44 @@ logger = logging.get_logger(__name__) # ==================== +# auto_docstring class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks): - """Text encoder that takes text prompt, will generate a prompt based on image if not provided.""" + """ + QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not + provided. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. 
+ resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + resized_image (`List`): + The resized images + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ model_name = "qwenimage-layered" block_classes = [ @@ -77,7 +107,32 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks): # Edit VAE encoder +# auto_docstring class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageLayeredResizeStep(), @@ -98,11 +153,60 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks): # assemble input steps +# auto_docstring class QwenImageLayeredInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the layered denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. 
(batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageTextInputsStep(), - QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]), + QwenImageLayeredAdditionalInputsStep(), ] block_names = ["text_inputs", "additional_inputs"] @@ -116,7 +220,48 @@ class QwenImageLayeredInputStep(SequentialPipelineBlocks): # Qwen Image Layered (image2image) core denoise step +# auto_docstring class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Layered img2img task. + + Components: + pachifier (`QwenImageLayeredPachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageLayeredInputStep(), @@ -142,9 +287,7 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.template("latents"), ] @@ -162,7 +305,54 @@ LAYERED_AUTO_BLOCKS = InsertableDict( ) +# auto_docstring class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for layered denoising tasks using QwenImage-Layered. 
+ + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) pachifier (`QwenImageLayeredPachifier`) + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-layered" block_classes = LAYERED_AUTO_BLOCKS.values() block_names = LAYERED_AUTO_BLOCKS.keys() @@ -174,5 +364,5 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.template("images"), ] diff --git a/src/diffusers/modular_pipelines/z_image/denoise.py b/src/diffusers/modular_pipelines/z_image/denoise.py index 3d5a00a9df..5f76a8459f 100644 --- a/src/diffusers/modular_pipelines/z_image/denoise.py +++ b/src/diffusers/modular_pipelines/z_image/denoise.py @@ -131,7 +131,7 @@ class ZImageLoopDenoiser(ModularPipelineBlocks): ), InputParam( kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + description="The conditional model inputs for the denoiser: e.g. 
prompt_embeds, negative_prompt_embeds, etc.", ), ] guider_input_names = [] diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py index 019c144152..3ea1ece36c 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py @@ -482,8 +482,6 @@ class ChromaInpaintPipeline( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, padding_mask_crop=None, max_sequence_length=None, @@ -531,15 +529,6 @@ class ChromaInpaintPipeline( f" {negative_prompt_embeds.shape}." ) - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") @@ -793,13 +782,11 @@ class ChromaInpaintPipeline( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_ip_adapter_image: Optional[PipelineImageInput] = None, negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py index f4bd0cc2d7..23c0e138c4 100644 --- a/src/diffusers/schedulers/scheduling_consistency_decoder.py +++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py @@ -14,7 +14,7 @@ from .scheduling_utils import SchedulerMixin def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -28,8 +28,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 74ade1d8bb..92c3e20013 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index 92f7a5ab3a..1a77a65278 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -100,14 +100,13 @@ def betas_for_alpha_bar( return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): +def rescale_zero_terminal_snr(alphas_cumprod: torch.Tensor) -> torch.Tensor: """ - Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - + Rescales betas to have zero terminal SNR. Based on [Algorithm 1](https://huggingface.co/papers/2305.08891) Args: - betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + alphas_cumprod (`torch.Tensor`): + The cumulative alpha products that the scheduler is being initialized with. Returns: `torch.Tensor`: rescaled betas with zero terminal SNR @@ -142,11 +141,11 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to 0.00085): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to 0.0120): The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`): + beta_schedule (`str`, defaults to `"scaled_linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): @@ -179,6 +178,8 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + snr_shift_scale (`float`, defaults to 3.0): + Shift scale for SNR. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -190,15 +191,15 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.0120, - beta_schedule: str = "scaled_linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading", rescale_betas_zero_snr: bool = False, snr_shift_scale: float = 3.0, ): @@ -208,7 +209,15 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float64, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -238,7 +247,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) - def _get_variance(self, timestep, prev_timestep): + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t @@ -265,7 +274,11 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): """ return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Optional[Union[str, torch.device]] = None, + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
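For readers skimming this hunk, the `snr_shift_scale` entry documented above can be made concrete with a small, self-contained sketch. This is not code lifted from `CogVideoXDDIMScheduler`; it assumes the common reading of an SNR shift, namely that the signal-to-noise ratio ᾱ/(1 − ᾱ) is divided by the scale, and it reuses only the defaults this diff records (`beta_start=0.00085`, `beta_end=0.0120`, `snr_shift_scale=3.0`, `scaled_linear` schedule).

```python
import torch

# Hedged sketch (not the scheduler itself): scaled_linear betas with the defaults
# documented above, followed by an SNR shift of the assumed form snr' = snr / scale.
num_train_timesteps = 1000
beta_start, beta_end, snr_shift_scale = 0.00085, 0.0120, 3.0

# scaled_linear: interpolate in sqrt(beta) space, then square.
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Dividing SNR = a / (1 - a) by snr_shift_scale is algebraically equivalent to:
alphas_cumprod = alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * alphas_cumprod)

print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())  # close to 1.0 at t=0, small at t=T-1
```

With `rescale_betas_zero_snr=True`, the `rescale_zero_terminal_snr` helper touched earlier in this diff would additionally push the terminal ᾱ to exactly zero, as its docstring describes.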
@@ -317,7 +330,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: @@ -328,7 +341,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. @@ -487,5 +500,5 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 802d8f7977..476f741bcd 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -22,6 +22,7 @@ import flax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -32,6 +33,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DDIMSchedulerState: common: CommonSchedulerState @@ -125,6 +129,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: @@ -152,7 +160,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ) def scale_model_input( - self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DDIMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Args: @@ -190,7 +201,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep): alpha_prod_t = state.common.alphas_cumprod[timestep] alpha_prod_t_prev = jnp.where( - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + prev_timestep >= 0, + state.common.alphas_cumprod[prev_timestep], + state.final_alpha_cumprod, ) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index e76ad9aa6c..a3c9ed1f62 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -49,7 +49,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -63,8 +63,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -99,7 +99,7 @@ def betas_for_alpha_bar( # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) @@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", clip_sample_range: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "trailing"] = "leading", rescale_betas_zero_snr: bool = False, **kwargs, ): @@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): """ return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Optional[Union[str, torch.device]] = None, + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - eta (`float`): - The weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`, defaults to `False`): - If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary - because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no - clipping has happened, "corrected" `model_output` would coincide with the one provided as input and - `use_clipped_model_output` has no effect. - variance_noise (`torch.Tensor`): - Alternative to generating noise with `generator` by directly providing the noise for the variance - itself. Useful for methods such as [`CycleDiffusion`]. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or `tuple`. @@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): # 1. get previous step value (=t+1) prev_timestep = timestep timestep = min( - timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1 + timestep - self.config.num_train_timesteps // self.num_inference_steps, + self.config.num_train_timesteps - 1, ) # 2. compute alphas, betas @@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): return (prev_sample, pred_original_sample) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 09f55ee4c2..76f0636fbf 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -51,7 +51,7 @@ class DDIMParallelSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. 
- alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -101,7 +101,7 @@ def betas_for_alpha_bar( # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) @@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): """ return sample - def _get_variance(self, timestep, prev_timestep=None): + def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor: if prev_timestep is None: prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps @@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): return variance - def _batch_get_variance(self, t, prev_t): + def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor: alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) @@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): return sample # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMParallelSchedulerOutput, Tuple]: @@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): sample (`torch.Tensor`): current instance of sample being created by diffusion process. eta (`float`): weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped - predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when - `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would - coincide with the one provided as input and `use_clipped_model_output` will have not effect. - generator: random number generator. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This + correction is necessary because the predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches + the input and `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + Random number generator. variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we can directly provide the noise for the variance itself. 
This is useful for methods such as CycleDiffusion. (https://huggingface.co/papers/2210.05559) @@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): if variance_noise is None: variance_noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, ) variance = std_dev_t * variance_noise @@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): def batch_step_no_noise( self, model_output: torch.Tensor, - timesteps: List[int], + timesteps: torch.Tensor, sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, @@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): direct output from learned diffusion model. - timesteps (`List[int]`): + timesteps (`torch.Tensor`): current discrete timesteps in the diffusion chain. This is now a list of integers. sample (`torch.Tensor`): current instance of sample being created by diffusion process. @@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d0596bb918..2e2816bbf3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -48,7 +48,7 @@ class DDPMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -62,8 +62,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -192,7 +192,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: Literal[ - "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range" + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", ] = "fixed_small", clip_sample: bool = True, prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", @@ -210,7 +215,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -268,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): @@ -337,7 +350,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): t: int, predicted_variance: Optional[torch.Tensor] = None, variance_type: Optional[ - Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"] + Literal[ + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", + ] ] = None, ) -> torch.Tensor: """ @@ -472,7 +492,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): prev_t = self.previous_timestep(t) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None @@ -521,7 +544,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): if t > 0: device = model_output.device variance_noise = randn_tensor( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=device, + dtype=model_output.dtype, ) if self.variance_type == "fixed_small_log": variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise @@ -620,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): def __len__(self) -> int: return self.config.num_train_timesteps - def previous_timestep(self, timestep: int) -> int: + def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]: """ Compute the previous timestep in the diffusion chain. @@ -629,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. 
""" if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index a3264f54f5..e02b7ea0c0 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -32,6 +33,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DDPMSchedulerState: common: CommonSchedulerState @@ -42,7 +46,12 @@ class DDPMSchedulerState: num_inference_steps: Optional[int] = None @classmethod - def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): + def create( + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) @@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: @@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ) def scale_model_input( - self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DDPMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Args: diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index ee7ab66be4..b02c5376f2 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -50,7 +50,7 @@ class DDPMParallelSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -64,8 +64,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -149,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): For more details, see the original paper: https://huggingface.co/papers/2006.11239 Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. 
- beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - variance_type (`str`): - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + trained_betas (`np.ndarray`, *optional*): + Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`, defaults to `"fixed_small"`): + Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - clip_sample (`bool`, default `True`): - option to clip predicted sample for numerical stability. - clip_sample_range (`float`, default `1.0`): - the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + clip_sample (`bool`, defaults to `True`): + Option to clip predicted sample for numerical stability. + prediction_type (`str`, defaults to `"epsilon"`): + Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://huggingface.co/papers/2210.02303) - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method (introduced by Imagen, https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`. - sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True`. - timestep_spacing (`str`, default `"leading"`): + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 
- steps_offset (`int`, default `0`): + steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and @@ -202,7 +205,12 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: Literal[ - "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range" + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", ] = "fixed_small", clip_sample: bool = True, prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", @@ -220,7 +228,15 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -280,7 +296,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): @@ -350,7 +366,14 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): t: int, predicted_variance: Optional[torch.Tensor] = None, variance_type: Optional[ - Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"] + Literal[ + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", + ] ] = None, ) -> torch.Tensor: """ @@ -458,7 +481,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): model_output: torch.Tensor, timestep: int, sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[DDPMParallelSchedulerOutput, Tuple]: """ @@ -470,7 +493,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.Tensor`): current instance of sample being created by diffusion process. - generator: random number generator. + generator (`torch.Generator`, *optional*): + Random number generator. 
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class Returns: @@ -483,7 +507,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): prev_t = self.previous_timestep(t) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None @@ -532,7 +559,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): if t > 0: device = model_output.device variance_noise = randn_tensor( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=device, + dtype=model_output.dtype, ) if self.variance_type == "fixed_small_log": variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise @@ -555,7 +585,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): def batch_step_no_noise( self, model_output: torch.Tensor, - timesteps: List[int], + timesteps: torch.Tensor, sample: torch.Tensor, ) -> torch.Tensor: """ @@ -568,8 +598,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): direct output from learned diffusion model. - timesteps (`List[int]`): - current discrete timesteps in the diffusion chain. This is now a list of integers. + timesteps (`torch.Tensor`): + Current discrete timesteps in the diffusion chain. This is a tensor of integers. sample (`torch.Tensor`): current instance of sample being created by diffusion process. @@ -583,7 +613,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): t = t.view(-1, *([1] * (model_output.ndim - 1))) prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: pass @@ -714,7 +747,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): return self.config.num_train_timesteps # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep - def previous_timestep(self, timestep): + def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]: """ Compute the previous timestep in the diffusion chain. @@ -723,7 +756,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. 
""" if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index ebc3a33b27..7c2dfd8e50 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -34,7 +34,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py index 66fb39c0bc..3e50ebbfe0 100644 --- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -52,7 +52,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 990129f584..07cb64f32b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -34,7 +34,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. 
+ alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 71b9960bf2..66398073b2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -31,6 +32,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DPMSolverMultistepSchedulerState: common: CommonSchedulerState @@ -171,6 +175,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): timestep_spacing: str = "linspace", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: @@ -203,7 +211,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ) def set_timesteps( - self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple + self, + state: DPMSolverMultistepSchedulerState, + num_inference_steps: int, + shape: Tuple, ) -> DPMSolverMultistepSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -301,10 +312,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): if self.config.thresholding: # Dynamic thresholding in https://huggingface.co/papers/2205.11487 dynamic_max_val = jnp.percentile( - jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) + jnp.abs(x0_pred), + self.config.dynamic_thresholding_ratio, + axis=tuple(range(1, x0_pred.ndim)), ) dynamic_max_val = jnp.maximum( - dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) + dynamic_max_val, + self.config.sample_max_value * jnp.ones_like(dynamic_max_val), ) x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val return x0_pred @@ -385,7 +399,11 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] + lambda_t, lambda_s0, lambda_s1 = ( + state.lambda_t[t], + state.lambda_t[s0], + state.lambda_t[s1], + ) alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 @@ -443,7 +461,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): Returns: `jnp.ndarray`: the sample tensor at the previous timestep. 
""" - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + t, s0, s1, s2 = ( + prev_timestep, + timestep_list[-1], + timestep_list[-2], + timestep_list[-3], + ) m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( state.lambda_t[t], @@ -615,7 +638,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) def scale_model_input( - self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DPMSolverMultistepSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index a9c4fe57b6..2da90d287c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -34,7 +34,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 5f9ce1393d..6f905a623d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -117,7 +117,7 @@ class BrownianTreeNoiseSampler: def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -131,8 +131,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index e92f880e5b..e9bf815aba 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -50,8 +50,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 0258ea7777..11fec60c9c 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -51,7 +51,7 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 4238c976e4..8b141325fb 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -54,7 +54,7 @@ class EulerDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -68,8 +68,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. 
Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py index 09341c909d..2bb6bf3558 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py @@ -19,6 +19,7 @@ import flax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -28,6 +29,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class EulerDiscreteSchedulerState: common: CommonSchedulerState @@ -40,9 +44,18 @@ class EulerDiscreteSchedulerState: @classmethod def create( - cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + sigmas: jnp.ndarray, ): - return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + return cls( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) @dataclass @@ -99,6 +112,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): timestep_spacing: str = "linspace", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState: @@ -146,7 +163,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): return sample def set_timesteps( - self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + self, + state: EulerDiscreteSchedulerState, + num_inference_steps: int, + shape: Tuple = (), ) -> EulerDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 
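Because this hunk only re-wraps the `set_timesteps` signature, the `timestep_spacing` behaviour it sits next to is easy to miss. The sketch below reproduces the `"linspace"` and `"leading"` branches shown in the surrounding code with plain NumPy (values only; the scheduler itself uses `jax.numpy`), using arbitrary example settings of 1000 training timesteps and 50 inference steps.

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 50

# "linspace": evenly spaced floats from T - 1 down to 0.
linspace_ts = np.linspace(num_train_timesteps - 1, 0, num_inference_steps)

# "leading": integer multiples of T // n, reversed.
step_ratio = num_train_timesteps // num_inference_steps
leading_ts = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)

print(linspace_ts[:3])  # approximately [999.   978.61 958.22]
print(leading_ts[:3])   # [980. 960. 940.]
```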
@@ -159,7 +179,12 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): """ if self.config.timestep_spacing == "linspace": - timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + timesteps = jnp.linspace( + self.config.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=self.dtype, + ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // num_inference_steps timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 011f97ba5c..0c5e28ad06 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -51,7 +51,7 @@ class HeunDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 37849e28b2..ee49ae67b9 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -52,7 +52,7 @@ class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 1c2791837c..6effb3699b 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -51,7 +51,7 @@ class KDPM2DiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index bacfbd6100..3f43a5fa99 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -22,10 +22,13 @@ import jax.numpy as jnp from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, logging from .scheduling_utils_flax import FlaxSchedulerMixin +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class KarrasVeSchedulerState: # setable values @@ -102,7 +105,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): s_min: float = 0.05, s_max: float = 50, ): - pass + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) def create_state(self): return KarrasVeSchedulerState.create() diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index 66dedd5a6e..ada8806e8c 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -53,7 +53,7 @@ class LCMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -67,8 +67,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): The current timestep. 
Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 9fc9b1e64b..a1f9d27fd9 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -49,7 +49,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -63,8 +63,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 3fd4dc8a5d..4edb091348 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,6 +20,7 @@ import jax.numpy as jnp from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -29,6 +30,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class LMSDiscreteSchedulerState: common: CommonSchedulerState @@ -44,9 +48,18 @@ class LMSDiscreteSchedulerState: @classmethod def create( - cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + sigmas: jnp.ndarray, ): - return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + return cls( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) @dataclass @@ -101,6 +114,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState: @@ -165,7 +182,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): return integrated_coeff def set_timesteps( - self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + self, + state: LMSDiscreteSchedulerState, + num_inference_steps: int, + shape: Tuple = (), ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 
@@ -177,7 +197,12 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): the number of diffusion steps used when generating samples with a pre-trained model. """ - timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + timesteps = jnp.linspace( + self.config.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=self.dtype, + ) low_idx = jnp.floor(timesteps).astype(jnp.int32) high_idx = jnp.ceil(timesteps).astype(jnp.int32) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index e95a374457..0820f5baa8 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -28,7 +28,7 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -42,8 +42,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 44bafccd55..bbef4649ec 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -31,6 +32,9 @@ from .scheduling_utils_flax import ( ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class PNDMSchedulerState: common: CommonSchedulerState @@ -131,6 +135,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype # For now we only support F-PNDM, i.e. 
the runge-kutta method @@ -190,7 +198,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): else: prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( - jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), + jnp.array( + [0, self.config.num_train_timesteps // num_inference_steps // 2], + dtype=jnp.int32, + ), self.pndm_order, ) @@ -218,7 +229,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): ) def scale_model_input( - self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: PNDMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -320,7 +334,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): ) diff_to_prev = jnp.where( - state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 + state.counter % 2, + 0, + self.config.num_train_timesteps // state.num_inference_steps // 2, ) prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] @@ -401,7 +417,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) timestep = jnp.where( - state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep + state.counter == 1, + timestep + self.config.num_train_timesteps // state.num_inference_steps, + timestep, ) # Reference: @@ -466,7 +484,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): # prev_sample -> x_(t−δ) alpha_prod_t = state.common.alphas_cumprod[timestep] alpha_prod_t_prev = jnp.where( - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + prev_timestep >= 0, + state.common.alphas_cumprod[prev_timestep], + state.final_alpha_cumprod, ) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index fcebe7e21c..bec4a1bdf6 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -47,7 +47,7 @@ class RePaintSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -61,8 +61,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 7c679a255c..565fae1c0d 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -35,7 +35,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -49,8 +49,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index 09cd081462..f4fe6d8f6b 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -23,7 +23,15 @@ import jax.numpy as jnp from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from ..utils import logging +from .scheduling_utils_flax import ( + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + + +logger = logging.get_logger(__name__) @flax.struct.dataclass @@ -95,7 +103,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): sampling_eps: float = 1e-5, correct_steps: int = 1, ): - pass + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) def create_state(self): state = ScoreSdeVeSchedulerState.create() @@ -108,7 +119,11 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): ) def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None + self, + state: ScoreSdeVeSchedulerState, + num_inference_steps: int, + shape: Tuple = (), + sampling_eps: float = None, ) -> ScoreSdeVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 7a385f6291..a1303436cd 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -52,7 +52,7 @@ class TCDSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. 
max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index bdc4feb0b1..14b09277da 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -48,7 +48,7 @@ class UnCLIPSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -62,8 +62,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 0536e8d1ed..d8e24d1964 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -34,7 +34,7 @@ if is_scipy_available(): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: @@ -226,6 +226,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): time_shift_type: Literal["exponential"] = "exponential", sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, + shift_terminal: Optional[float] = None, ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -245,6 +246,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + if shift_terminal is not None and not use_flow_sigmas: + raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.") if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) @@ -313,8 +316,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self._begin_index = begin_index def set_timesteps( - self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None - ) -> None: + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -323,13 +330,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. mu (`float`, *optional*): Optional mu parameter for dynamic shifting when using exponential time shift type. """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") + + if sigmas is not None: + if not self.config.use_flow_sigmas: + raise ValueError( + "Passing `sigmas` is only supported when `use_flow_sigmas=True`. " + "Please set `use_flow_sigmas=True` during scheduler initialization." + ) + num_inference_steps = len(sigmas) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 - if mu is not None: - assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" - self.config.flow_shift = np.exp(mu) if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) @@ -354,8 +372,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -375,6 +394,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_exponential_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -389,6 +410,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_beta_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -403,9 +426,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_flow_sigmas: - alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) - sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() + if sigmas is None: + sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1] + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas) + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + eps = 1e-6 + if np.fabs(sigmas[0] - 1) < eps: + # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update + sigmas[0] -= eps timesteps = (sigmas * self.config.num_train_timesteps).copy() if self.config.final_sigmas_type == "sigma_min": sigma_last = sigmas[-1] @@ -417,6 +449,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) else: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 @@ -446,6 +480,43 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches 
and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py index 8a693e9c2d..df1dd2d987 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py @@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas def test_float16_inference(self): super().test_float16_inference(expected_max_diff=5e-1) + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-3, rtol=1e-3) + @is_flaky() def test_model_cpu_offload_forward_pass(self): super().test_inference_batch_single_identical(expected_max_diff=8e-4) diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py index 503fdb242d..d3bfa4b308 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py @@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase) def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=1e-2) + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-3, rtol=1e-3) + @slow @require_torch_accelerator diff --git a/utils/modular_auto_docstring.py b/utils/modular_auto_docstring.py new file mode 100644 index 0000000000..fc4a82f98e --- /dev/null +++ b/utils/modular_auto_docstring.py @@ -0,0 +1,352 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Auto Docstring Generator for Modular Pipeline Blocks + +This script scans Python files for classes that have `# auto_docstring` comment above them +and inserts/updates the docstring from the class's `doc` property. + +Run from the root of the repo: + python utils/modular_auto_docstring.py [path] [--fix_and_overwrite] + +Examples: + # Check for auto_docstring markers (will error if found without proper docstring) + python utils/modular_auto_docstring.py + + # Check specific directory + python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/ + + # Fix and overwrite the docstrings + python utils/modular_auto_docstring.py --fix_and_overwrite + +Usage in code: + # auto_docstring + class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): + # docstring will be automatically inserted here + + @property + def doc(self): + return "Your docstring content..." +""" + +import argparse +import ast +import glob +import importlib +import os +import re +import subprocess +import sys + + +# All paths are set with the intent you should run this script from the root of the repo +DIFFUSERS_PATH = "src/diffusers" +REPO_PATH = "." + +# Pattern to match the auto_docstring comment +AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$") + + +def setup_diffusers_import(): + """Setup import path to use the local diffusers module.""" + src_path = os.path.join(REPO_PATH, "src") + if src_path not in sys.path: + sys.path.insert(0, src_path) + + +def get_module_from_filepath(filepath: str) -> str: + """Convert a filepath to a module name.""" + filepath = os.path.normpath(filepath) + + if filepath.startswith("src" + os.sep): + filepath = filepath[4:] + + if filepath.endswith(".py"): + filepath = filepath[:-3] + + module_name = filepath.replace(os.sep, ".") + return module_name + + +def load_module(filepath: str): + """Load a module from filepath.""" + setup_diffusers_import() + module_name = get_module_from_filepath(filepath) + + try: + module = importlib.import_module(module_name) + return module + except Exception as e: + print(f"Warning: Could not import module {module_name}: {e}") + return None + + +def get_doc_from_class(module, class_name: str) -> str: + """Get the doc property from an instantiated class.""" + if module is None: + return None + + cls = getattr(module, class_name, None) + if cls is None: + return None + + try: + instance = cls() + if hasattr(instance, "doc"): + return instance.doc + except Exception as e: + print(f"Warning: Could not instantiate {class_name}: {e}") + + return None + + +def find_auto_docstring_classes(filepath: str) -> list: + """ + Find all classes in a file that have # auto_docstring comment above them. 
+ + Returns list of (class_name, class_line_number, has_existing_docstring, docstring_end_line) + """ + with open(filepath, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Parse AST to find class locations and their docstrings + content = "".join(lines) + try: + tree = ast.parse(content) + except SyntaxError as e: + print(f"Syntax error in {filepath}: {e}") + return [] + + # Build a map of class_name -> (class_line, has_docstring, docstring_end_line) + class_info = {} + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + has_docstring = False + docstring_end_line = node.lineno # default to class line + + if node.body and isinstance(node.body[0], ast.Expr): + first_stmt = node.body[0] + if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str): + has_docstring = True + docstring_end_line = first_stmt.end_lineno or first_stmt.lineno + + class_info[node.name] = (node.lineno, has_docstring, docstring_end_line) + + # Now scan for # auto_docstring comments + classes_to_update = [] + + for i, line in enumerate(lines): + if AUTO_DOCSTRING_PATTERN.match(line): + # Found the marker, look for class definition on next non-empty, non-comment line + j = i + 1 + while j < len(lines): + next_line = lines[j].strip() + if next_line and not next_line.startswith("#"): + break + j += 1 + + if j < len(lines) and lines[j].strip().startswith("class "): + # Extract class name + match = re.match(r"class\s+(\w+)", lines[j].strip()) + if match: + class_name = match.group(1) + if class_name in class_info: + class_line, has_docstring, docstring_end_line = class_info[class_name] + classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line)) + + return classes_to_update + + +def strip_class_name_line(doc: str, class_name: str) -> str: + """Remove the 'class ClassName' line from the doc if present.""" + lines = doc.strip().split("\n") + if lines and lines[0].strip() == f"class {class_name}": + # Remove the class line and any blank line following it + lines = lines[1:] + while lines and not lines[0].strip(): + lines = lines[1:] + return "\n".join(lines) + + +def format_docstring(doc: str, indent: str = " ") -> str: + """Format a doc string as a properly indented docstring.""" + lines = doc.strip().split("\n") + + if len(lines) == 1: + return f'{indent}"""{lines[0]}"""\n' + else: + result = [f'{indent}"""\n'] + for line in lines: + if line.strip(): + result.append(f"{indent}{line}\n") + else: + result.append("\n") + result.append(f'{indent}"""\n') + return "".join(result) + + +def run_ruff_format(filepath: str): + """Run ruff check --fix, ruff format, and doc-builder style on a file to ensure consistent formatting.""" + try: + # First run ruff check --fix to fix any linting issues (including line length) + subprocess.run( + ["ruff", "check", "--fix", filepath], + check=False, # Don't fail if there are unfixable issues + capture_output=True, + text=True, + ) + # Then run ruff format for code formatting + subprocess.run( + ["ruff", "format", filepath], + check=True, + capture_output=True, + text=True, + ) + # Finally run doc-builder style for docstring formatting + subprocess.run( + ["doc-builder", "style", filepath, "--max_len", "119"], + check=False, # Don't fail if doc-builder has issues + capture_output=True, + text=True, + ) + print(f"Formatted {filepath}") + except subprocess.CalledProcessError as e: + print(f"Warning: formatting failed for {filepath}: {e.stderr}") + except FileNotFoundError as e: + print(f"Warning: tool not 
found ({e}). Skipping formatting.") + except Exception as e: + print(f"Warning: unexpected error formatting {filepath}: {e}") + + +def get_existing_docstring(lines: list, class_line: int, docstring_end_line: int) -> str: + """Extract the existing docstring content from lines.""" + # class_line is 1-indexed, docstring starts at class_line (0-indexed: class_line) + # and ends at docstring_end_line (1-indexed, inclusive) + docstring_lines = lines[class_line:docstring_end_line] + return "".join(docstring_lines) + + +def process_file(filepath: str, overwrite: bool = False) -> list: + """ + Process a file and find/insert docstrings for # auto_docstring marked classes. + + Returns list of classes that need updating. + """ + classes_to_update = find_auto_docstring_classes(filepath) + + if not classes_to_update: + return [] + + if not overwrite: + # Check mode: only verify that docstrings exist + # Content comparison is not reliable due to formatting differences + classes_needing_update = [] + for class_name, class_line, has_docstring, docstring_end_line in classes_to_update: + if not has_docstring: + # No docstring exists, needs update + classes_needing_update.append((filepath, class_name, class_line)) + return classes_needing_update + + # Load the module to get doc properties + module = load_module(filepath) + + with open(filepath, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Process in reverse order to maintain line numbers + updated = False + for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update): + doc = get_doc_from_class(module, class_name) + + if doc is None: + print(f"Warning: Could not get doc for {class_name} in {filepath}") + continue + + # Remove the "class ClassName" line since it's redundant in a docstring + doc = strip_class_name_line(doc, class_name) + + # Format the new docstring with 4-space indent + new_docstring = format_docstring(doc, " ") + + if has_docstring: + # Replace existing docstring (line after class definition to docstring_end_line) + # class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line + lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:] + else: + # Insert new docstring right after class definition line + # class_line is 1-indexed, so lines[class_line-1] is the class line + # Insert at position class_line (which is right after the class line) + lines = lines[:class_line] + [new_docstring] + lines[class_line:] + + updated = True + print(f"Updated docstring for {class_name} in {filepath}") + + if updated: + with open(filepath, "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines) + # Run ruff format to ensure consistent line wrapping + run_ruff_format(filepath) + + return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update] + + +def check_auto_docstrings(path: str = None, overwrite: bool = False): + """ + Check all files for # auto_docstring markers and optionally fix them. 
+ """ + if path is None: + path = DIFFUSERS_PATH + + if os.path.isfile(path): + all_files = [path] + else: + all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True) + + all_markers = [] + + for filepath in all_files: + markers = process_file(filepath, overwrite) + all_markers.extend(markers) + + if not overwrite and len(all_markers) > 0: + message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers]) + raise ValueError( + f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n" + f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them." + ) + + if overwrite and len(all_markers) > 0: + print(f"\nProcessed {len(all_markers)} docstring(s).") + elif not overwrite and len(all_markers) == 0: + print("All # auto_docstring markers have valid docstrings.") + elif len(all_markers) == 0: + print("No # auto_docstring markers found.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Check and fix # auto_docstring markers in modular pipeline blocks", + ) + parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)") + parser.add_argument( + "--fix_and_overwrite", + action="store_true", + help="Whether to fix the docstrings by inserting them from doc property.", + ) + + args = parser.parse_args() + + check_auto_docstrings(args.path, args.fix_and_overwrite)