mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[qwen-image] fix pr comments

This commit is contained in:
naykun
2025-12-17 18:14:05 +08:00
parent e630e6eca4
commit f3c62427cb
2 changed files with 31 additions and 88 deletions

View File

@@ -143,23 +143,23 @@ def apply_rotary_emb_qwen(
 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, additional_t_cond=False):
+    def __init__(self, embedding_dim, use_additional_t_cond=False):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
-        self.additional_t_cond = additional_t_cond
-        if additional_t_cond:
+        self.use_additional_t_cond = use_additional_t_cond
+        if use_additional_t_cond:
             self.addition_t_embedding = nn.Embedding(2, embedding_dim)
             self.addition_t_embedding.weight.data.zero_()
 
     def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)
 
         conditioning = timesteps_emb
-        if self.additional_t_cond:
-            assert addition_t_cond is not None, "When additional_t_cond is True, addition_t_cond must be provided."
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.")
             addition_t_emb = self.addition_t_embedding(addition_t_cond)
             addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
             conditioning = conditioning + addition_t_emb
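For reference, the renamed flag and the assert-to-ValueError swap behave as in this minimal sketch; ToyTimestepEmbed and its layers are illustrative stand-ins, not the actual diffusers module:

import torch
from torch import nn


class ToyTimestepEmbed(nn.Module):
    """Illustrative stand-in for QwenTimestepProjEmbeddings above."""

    def __init__(self, embedding_dim: int, use_additional_t_cond: bool = False):
        super().__init__()
        self.use_additional_t_cond = use_additional_t_cond
        self.proj = nn.Linear(1, embedding_dim)  # placeholder for the sinusoidal + MLP embedder
        if use_additional_t_cond:
            # Two-entry table initialized to zero, mirroring the hunk above.
            self.addition_t_embedding = nn.Embedding(2, embedding_dim)
            self.addition_t_embedding.weight.data.zero_()

    def forward(self, timestep, addition_t_cond=None):
        conditioning = self.proj(timestep.float().unsqueeze(-1))
        if self.use_additional_t_cond:
            # Explicit raise instead of `assert`, so the check survives `python -O`.
            if addition_t_cond is None:
                raise ValueError("When use_additional_t_cond is True, addition_t_cond must be provided.")
            conditioning = conditioning + self.addition_t_embedding(addition_t_cond)
        return conditioning


emb = ToyTimestepEmbed(embedding_dim=8, use_additional_t_cond=True)
print(emb(torch.tensor([500, 999]), addition_t_cond=torch.tensor([0, 1])).shape)  # torch.Size([2, 8])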
@@ -291,9 +291,7 @@ class QwenEmbedLayer3DRope(nn.Module):
             ],
             dim=1,
         )
-        self.rope_cache = {}
-        # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
         self.scale_rope = scale_rope
 
     def rope_params(self, index, dim, theta=10000):
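The all-caps comment in this hunk points at a real PyTorch behavior: `Module.to(dtype=...)` also casts complex buffers to the target floating-point dtype, which silently drops the imaginary part, while a plain attribute is left alone. A small self-contained check (names here are illustrative):

import torch
from torch import nn


class WithComplexFreqs(nn.Module):
    def __init__(self):
        super().__init__()
        freqs = torch.polar(torch.ones(4), torch.arange(4.0))  # complex64 rotary-style frequencies
        self.register_buffer("freqs_buffer", freqs.clone())    # would be cast by .to(dtype=...)
        self.freqs_attr = freqs                                 # plain attribute, untouched by .to(dtype=...)


m = WithComplexFreqs().to(torch.bfloat16)
print(m.freqs_buffer.dtype)  # torch.bfloat16 -> imaginary part discarded (with a warning)
print(m.freqs_attr.dtype)    # torch.complex64 -> still complex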
@@ -703,7 +701,7 @@ class QwenImageTransformer2DModel(
         guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         zero_cond_t: bool = False,
-        additional_t_cond: bool = False,
+        use_additional_t_cond: bool = False,
         use_layer3d_rope: bool = False,
     ):
         super().__init__()
@@ -716,7 +714,7 @@ class QwenImageTransformer2DModel(
         self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
 
         self.time_text_embed = QwenTimestepProjEmbeddings(
-            embedding_dim=self.inner_dim, additional_t_cond=additional_t_cond
+            embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond
         )
 
         self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

View File

@@ -18,14 +18,13 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import numpy as np
 import torch
-from PIL import Image
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import QwenImageLoraLoaderMixin
 from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
@@ -152,6 +151,7 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions
 def calculate_dimensions(target_area, ratio):
     width = math.sqrt(target_area * ratio)
     height = width / ratio
@@ -159,7 +159,7 @@ def calculate_dimensions(target_area, ratio):
     width = round(width / 32) * 32
     height = round(height / 32) * 32
-    return width, height, None
+    return width, height
 
 
 class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
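Pieced together from the two hunks above, the updated helper now returns just `(width, height)`, both snapped to multiples of 32. A quick standalone check of the rounding:

import math


def calculate_dimensions(target_area, ratio):
    # Same logic as in the hunks above: fit the aspect ratio into the pixel budget,
    # then snap both sides to multiples of 32.
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    width = round(width / 32) * 32
    height = round(height / 32) * 32
    return width, height


print(calculate_dimensions(640 * 640, 16 / 9))  # (864, 480)
print(calculate_dimensions(1024 * 1024, 1.0))   # (1024, 1024)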
@@ -266,6 +266,7 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         return prompt_embeds, encoder_attention_mask
 
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -296,6 +297,9 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
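The added truncation runs before the per-prompt duplication, so `seq_len` is read from the already-truncated embeddings. A shape walk-through with made-up sizes:

import torch

batch_size, num_images_per_prompt, max_sequence_length = 2, 3, 512
prompt_embeds = torch.randn(batch_size, 700, 16)                    # 700 tokens, hidden size 16 (illustrative)
prompt_embeds_mask = torch.ones(batch_size, 700, dtype=torch.long)

# Truncate first (the lines added above), then duplicate per requested image.
prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
print(prompt_embeds.shape)  # torch.Size([6, 512, 16])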
@@ -393,6 +397,7 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         return latents
 
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
@@ -416,59 +421,6 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         return image_latents
 
-    def enable_vae_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
-        deprecate(
-            "enable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.enable_slicing()
-
-    def disable_vae_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
-        deprecate(
-            "disable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.disable_slicing()
-
-    def enable_vae_tiling(self):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
-        deprecate(
-            "enable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.enable_tiling()
-
-    def disable_vae_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
-        deprecate(
-            "disable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.disable_tiling()
-
     def prepare_latents(
         self,
         image,
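With the deprecated wrappers gone, slicing and tiling are enabled directly on the VAE, exactly as the removed deprecation messages suggested. A usage sketch (the checkpoint id is a placeholder for whichever Qwen-Image checkpoint this pipeline ships with):

import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint id; substitute the actual layered Qwen-Image checkpoint.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# Instead of pipe.enable_vae_slicing() / pipe.enable_vae_tiling():
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()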
@@ -560,8 +512,6 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         prompt: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
         true_cfg_scale: float = 4.0,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
         layers: Optional[int] = 4,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
@@ -607,10 +557,6 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
                 encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
                 lower image quality.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -663,7 +609,7 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-            resolution (`int`, *optional*, defaults to 640)
+            resolution (`int`, *optional*, defaults to 640):
                 Which resolution bucket (640 or 1024) to use for the condition and output resolution.
             cfg_normalize (`bool`, *optional*, defaults to `False`):
                 Whether to enable CFG normalization.
@@ -679,7 +625,7 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
"""
image_size = image[0].size if isinstance(image, list) else image.size
assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}"
calculated_width, calculated_height, _ = calculate_dimensions(
calculated_width, calculated_height = calculate_dimensions(
resolution * resolution, image_size[0] / image_size[1]
)
height = calculated_height
@@ -718,9 +664,6 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if prompt is None or prompt == "" or prompt == " ":
             prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=device)
-            print(f"Generated prompt: {prompt}")
-        else:
-            print(f"User prompt: {prompt}")
 
         # 3. Define call parameters
         if prompt is not None and isinstance(prompt, str):
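If the caption still needs to be surfaced without `print`, the module logger kept in the import hunk above is the usual diffusers route; this is a hedged sketch, not something this commit adds:

from diffusers.utils import logging

logger = logging.get_logger(__name__)

# e.g. inside __call__, in place of the removed print statements:
prompt = "a demo caption"
logger.info("Generated prompt: %s", prompt)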
@@ -917,19 +860,21 @@ class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 latents.device, latents.dtype
             )
             latents = latents / latents_std + latents_mean
-            latents = torch.unbind(latents, 2)
-            image = []
-            for z in latents[1:]:
-                z = z.unsqueeze(2)  # b c f h w
-                image.append(self.vae.decode(z, return_dict=False)[0])
-            image = torch.cat(image, dim=2)  # b c f h w
-            image = image.permute(0, 2, 3, 4, 1)  # b f h w c
-            image = (image * 0.5 + 0.5).clamp(0, 1).cpu().float().numpy()
-            image = (image * 255).round().astype("uint8")
+            b, c, f, h, w = latents.shape
+            latents = latents[:, :, 1:]  # remove the first frame as it is the original input
+            latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
+            image = self.vae.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
+            image = image.squeeze(2)
+            image = self.image_processor.postprocess(image, output_type=output_type)
 
             images = []
-            for layers in image:
-                images.append([Image.fromarray(layer) for layer in layers])
+            for bidx in range(b):
+                images.append(image[bidx * f : (bidx + 1) * f])
 
         # Offload all models
         self.maybe_free_model_hooks()
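The reworked decode path replaces the per-layer loop with one batched VAE call followed by a regroup per prompt. The shape bookkeeping can be checked in isolation; the sketch below uses made-up sizes, a stub in place of `self.vae.decode`, `reshape` (the permuted tensor is no longer contiguous), and groups by the number of kept layers:

import torch

b, c, f, h, w = 2, 16, 5, 32, 32             # f = 1 input frame + 4 generated layers (illustrative)
latents = torch.randn(b, c, f, h, w)

latents = latents[:, :, 1:]                  # drop the first frame (the original input)
layers_per_sample = latents.shape[2]         # f - 1
latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)  # (b * (f - 1), c, 1, h, w)


def fake_vae_decode(z):                      # stand-in for self.vae.decode(z, return_dict=False)[0]
    return torch.randn(z.shape[0], 3, 1, z.shape[-2] * 8, z.shape[-1] * 8)


image = fake_vae_decode(latents).squeeze(2)  # (b * (f - 1), 3, H, W); postprocess() would run here
images = [image[i * layers_per_sample : (i + 1) * layers_per_sample] for i in range(b)]
print(len(images), images[0].shape)          # 2 torch.Size([4, 3, 256, 256])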