1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into requirements-custom-blocks

This commit is contained in:
Sayak Paul
2025-09-02 10:04:22 +05:30
committed by GitHub
13 changed files with 1195 additions and 27 deletions

View File

@@ -120,6 +120,12 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
- all
- __call__
## QwenImageEditInpaintPipeline
[[autodoc]] QwenImageEditInpaintPipeline
- all
- __call__
## QwenImaggeControlNetPipeline
- all
- __call__

View File

@@ -223,7 +223,7 @@ from diffusers.image_processor import VaeImageProcessor
import torch
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
with torch.no_grad():

View File

@@ -223,7 +223,7 @@ from diffusers.image_processor import VaeImageProcessor
import torch
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
with torch.no_grad():

View File

@@ -1270,6 +1270,7 @@ def main(args):
subfolder="transformer",
revision=args.revision,
variant=args.variant,
torch_dtype=torch_dtype,
)
pipeline = FluxKontextPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1292,7 +1293,8 @@ def main(args):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
images = pipeline(example["prompt"]).images
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
images = pipeline(prompt=example["prompt"]).images
for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1899,6 +1901,10 @@ def main(args):
device=accelerator.device,
prompt=args.instance_prompt,
)
else:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
)
# Convert images to latent space
if args.cache_latents:

View File

@@ -494,6 +494,7 @@ else:
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
"QwenImageEditPipeline",
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
@@ -1134,6 +1135,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,

View File

@@ -2129,6 +2129,10 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
if has_lora_unet:
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
@@ -2201,29 +2205,44 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
up_key = ".lora_up.weight"
a_key = ".lora_A.weight"
b_key = ".lora_B.weight"
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
alpha_key = k.replace(down_key, ".alpha")
if has_non_diffusers_lora_id:
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(down_key, up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
alpha_key = k.replace(down_key, ".alpha")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(down_key, up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
# Already in diffusers format (lora_A/lora_B), just pop
elif has_diffusers_lora_id:
for k in all_keys:
if a_key in k or b_key in k:
converted_state_dict[k] = state_dict.pop(k)
elif ".alpha" in k:
state_dict.pop(k)
if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

View File

@@ -6684,7 +6684,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
if has_alphas_in_sd or has_lora_unet:
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict

View File

@@ -955,12 +955,13 @@ def _native_npu_attention(
dropout_p: float = 0.0,
scale: Optional[float] = None,
) -> torch.Tensor:
return npu_fusion_attention(
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
input_layout="BSND",
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
@@ -969,6 +970,8 @@ def _native_npu_attention(
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
return out
# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853

View File

@@ -393,6 +393,7 @@ else:
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImageEditPipeline",
"QwenImageEditInpaintPipeline",
"QwenImageControlNetPipeline",
]
try:
@@ -714,6 +715,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .qwenimage import (
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,

View File

@@ -26,6 +26,7 @@ else:
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
_import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"]
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
@@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_qwenimage import QwenImagePipeline
from .pipeline_qwenimage_controlnet import QwenImageControlNetPipeline
from .pipeline_qwenimage_edit import QwenImageEditPipeline
from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
else:

View File

@@ -551,6 +551,12 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Function invoked when calling the pipeline for generation.
Args:
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.

File diff suppressed because it is too large Load Diff

View File

@@ -1772,6 +1772,21 @@ class QwenImageControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class QwenImageEditInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class QwenImageEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]