diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 3068ffdcf2..1229bab169 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -143,23 +143,23 @@ def apply_rotary_emb_qwen(
 
 
 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, additional_t_cond=False):
+    def __init__(self, embedding_dim, use_additional_t_cond=False):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
-        self.additional_t_cond = additional_t_cond
-        if additional_t_cond:
+        self.use_additional_t_cond = use_additional_t_cond
+        if use_additional_t_cond:
             self.addition_t_embedding = nn.Embedding(2, embedding_dim)
-            self.addition_t_embedding.weight.data.zero_()
 
     def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)
 
         conditioning = timesteps_emb
-        if self.additional_t_cond:
-            assert addition_t_cond is not None, "When additional_t_cond is True, addition_t_cond must be provided."
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                raise ValueError("When `use_additional_t_cond` is True, `addition_t_cond` must be provided.")
             addition_t_emb = self.addition_t_embedding(addition_t_cond)
             addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
             conditioning = conditioning + addition_t_emb
@@ -291,9 +291,7 @@ class QwenEmbedLayer3DRope(nn.Module):
             ],
             dim=1,
         )
-        self.rope_cache = {}
 
-        # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
         self.scale_rope = scale_rope
 
     def rope_params(self, index, dim, theta=10000):
@@ -703,7 +701,7 @@ class QwenImageTransformer2DModel(
         guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         zero_cond_t: bool = False,
-        additional_t_cond: bool = False,
+        use_additional_t_cond: bool = False,
         use_layer3d_rope: bool = False,
     ):
         super().__init__()
@@ -716,7 +714,7 @@ class QwenImageTransformer2DModel(
         self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
 
         self.time_text_embed = QwenTimestepProjEmbeddings(
-            embedding_dim=self.inner_dim, additional_t_cond=additional_t_cond
+            embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond
         )
 
         self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
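For context, a minimal sketch (not part of the diff) of how the renamed constructor flag interacts with the unchanged `addition_t_cond` forward argument; `embedding_dim=3072` and the tensor shapes are illustrative assumptions:

```python
import torch

from diffusers.models.transformers.transformer_qwenimage import QwenTimestepProjEmbeddings

# Instantiate with the renamed flag.
embed = QwenTimestepProjEmbeddings(embedding_dim=3072, use_additional_t_cond=True)

timestep = torch.tensor([999.0])
hidden_states = torch.randn(1, 16, 3072)  # only used to pick the output dtype
addition_t_cond = torch.tensor([1])       # binary index into nn.Embedding(2, embedding_dim)

conditioning = embed(timestep, hidden_states, addition_t_cond=addition_t_cond)
print(conditioning.shape)  # torch.Size([1, 3072])

# Omitting addition_t_cond now raises ValueError instead of AssertionError,
# so the check also survives `python -O`.
```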
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
index 0c637c5929..b716b2f4e4 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
@@ -18,14 +18,13 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from PIL import Image
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import QwenImageLoraLoaderMixin
 from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
@@ -152,6 +151,7 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions
 def calculate_dimensions(target_area, ratio):
     width = math.sqrt(target_area * ratio)
     height = width / ratio
@@ -159,7 +159,7 @@ def calculate_dimensions(target_area, ratio):
     width = round(width / 32) * 32
     height = round(height / 32) * 32
 
-    return width, height, None
+    return width, height
 
 
 class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
@@ -266,6 +266,7 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
         return prompt_embeds, encoder_attention_mask
 
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -296,6 +297,9 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -393,6 +397,7 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
         return latents
 
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
@@ -416,59 +421,6 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
         return image_latents
 
-    def enable_vae_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
-        deprecate(
-            "enable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.enable_slicing()
-
-    def disable_vae_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
-        deprecate(
-            "disable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.disable_slicing()
-
-    def enable_vae_tiling(self):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
-        allow processing larger images.
-        """
-        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
-        deprecate(
-            "enable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.enable_tiling()
-
-    def disable_vae_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
-        deprecate(
-            "disable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.disable_tiling()
-
     def prepare_latents(
         self,
         image,
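With the deprecated wrappers removed, callers go through the VAE directly, exactly as the old deprecation messages advised; `pipe` below is a hypothetical pipeline instance:

```python
# pipe = QwenImageLayeredPipeline.from_pretrained(...)  # hypothetical handle/checkpoint
pipe.vae.enable_slicing()    # was: pipe.enable_vae_slicing()
pipe.vae.enable_tiling()     # was: pipe.enable_vae_tiling()
pipe.vae.disable_slicing()   # was: pipe.disable_vae_slicing()
pipe.vae.disable_tiling()    # was: pipe.disable_vae_tiling()
```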
- """ - depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." - deprecate( - "enable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." - deprecate( - "disable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.disable_tiling() - def prepare_latents( self, image, @@ -560,8 +512,6 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, true_cfg_scale: float = 4.0, - height: Optional[int] = None, - width: Optional[int] = None, layers: Optional[int] = 4, num_inference_steps: int = 50, sigmas: Optional[List[float]] = None, @@ -607,10 +557,6 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -663,7 +609,7 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - resolution (`int`, *optional*, defaults to 640) + resolution (`int`, *optional*, defaults to 640): using different bucket in (640, 1024) to determin the condition and output resolution cfg_normalize (`bool`, *optional*, defaults to `False`) whether enable cfg normalization. @@ -679,7 +625,7 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): """ image_size = image[0].size if isinstance(image, list) else image.size assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}" - calculated_width, calculated_height, _ = calculate_dimensions( + calculated_width, calculated_height = calculate_dimensions( resolution * resolution, image_size[0] / image_size[1] ) height = calculated_height @@ -718,9 +664,6 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if prompt is None or prompt == "" or prompt == " ": prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=device) - print(f"Generated prompt: {prompt}") - else: - print(f"User prompt: {prompt}") # 3. 
@@ -917,19 +860,21 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 latents.device, latents.dtype
             )
             latents = latents / latents_std + latents_mean
-            latents = torch.unbind(latents, 2)
-            image = []
-            for z in latents[1:]:
-                z = z.unsqueeze(2)  # b c f h w
-                image.append(self.vae.decode(z, return_dict=False)[0])
-            image = torch.cat(image, dim=2)  # b c f h w
-            image = image.permute(0, 2, 3, 4, 1)  # b f h w c
-            image = (image * 0.5 + 0.5).clamp(0, 1).cpu().float().numpy()
-            image = (image * 255).round().astype("uint8")
+            latents = latents[:, :, 1:]  # remove the first frame as it is the original input
+
+            b, c, f, h, w = latents.shape  # read shapes after the drop, so f counts output layers only
+
+            latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
+
+            image = self.vae.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
+
+            image = image.squeeze(2)
+
+            image = self.image_processor.postprocess(image, output_type=output_type)
 
             images = []
-            for layers in image:
-                images.append([Image.fromarray(layer) for layer in layers])
+            for bidx in range(b):
+                images.append(image[bidx * f : (bidx + 1) * f])
 
         # Offload all models
         self.maybe_free_model_hooks()
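A shape-only sketch of the new batched decode path: dummy tensors stand in for the VAE and postprocessing, and `b`/`c`/`f`/`h`/`w` mirror the hunk above. Reading the shape after dropping the first frame keeps the per-batch grouping correct for `b > 1`, and `reshape` (rather than `view`) is needed because the sliced, permuted tensor is not contiguous:

```python
import torch

b, c, f, h, w = 2, 16, 5, 32, 32   # batch, channels, 1 base frame + 4 layer frames
latents = torch.randn(b, c, f, h, w)

latents = latents[:, :, 1:]        # drop the first frame (the original input image)
b, c, f, h, w = latents.shape      # f == 4: output layers only

# Fold the frame axis into the batch axis: each layer becomes a single-frame "video",
# so one vae.decode call handles all layers of all batch entries at once.
latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
print(latents.shape)               # torch.Size([8, 16, 1, 32, 32])

# After decoding and postprocessing, regroup the flat list per original batch entry.
decoded = list(range(b * f))       # stand-in for the flat list of postprocessed images
images = [decoded[i * f : (i + 1) * f] for i in range(b)]
print(images)                      # [[0, 1, 2, 3], [4, 5, 6, 7]]
```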