mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
yiyixuxu committed on 2025-08-08 19:02:45 +02:00
parent 4b367e8edd, commit 1bea2d8eea
5 changed files with 116 additions and 122 deletions

View File

@@ -15,13 +15,11 @@
import inspect
from typing import Any, List, Optional, Tuple, Union
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
from ...models import ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging
@@ -591,7 +589,11 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
type_hint=torch.Tensor,
description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
InputParam(
"dtype",
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
]
@property
@@ -618,7 +620,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
is_strength_max=True,
add_noise=True,
):
batch_size = image_latents.shape[0]
if isinstance(generator, list) and len(generator) != batch_size:
@@ -640,46 +641,50 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
return latents, noise
def check_inputs(self, batch_size, image_latents, mask, masked_image_latents):
if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
raise ValueError(
f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
)
if not (mask.shape[0] == 1 or mask.shape[0] == batch_size):
raise ValueError(f"mask should have have batch size 1 or {batch_size}, but got {mask.shape[0]}")
if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
raise ValueError(f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")
raise ValueError(
f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}"
)
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(
batch_size=block_state.batch_size,
image_latents=block_state.image_latents,
mask=block_state.mask,
masked_image_latents=block_state.masked_image_latents,
)
image_latents=block_state.image_latents,
mask=block_state.mask,
masked_image_latents=block_state.masked_image_latents,
)
dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
device = components._execution_device
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
block_state.image_latents = block_state.image_latents.repeat(
final_batch_size // block_state.image_latents.shape[0], 1, 1, 1
)
# 7. Prepare mask latent variables
block_state.mask = block_state.mask.to(device=device, dtype=dtype)
block_state.mask = block_state.mask.repeat(final_batch_size//block_state.mask.shape[0], 1, 1, 1)
block_state.mask = block_state.mask.repeat(final_batch_size // block_state.mask.shape[0], 1, 1, 1)
block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size//block_state.masked_image_latents.shape[0], 1, 1, 1)
block_state.masked_image_latents = block_state.masked_image_latents.repeat(
final_batch_size // block_state.masked_image_latents.shape[0], 1, 1, 1
)
if block_state.latent_timestep is not None:
block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype)
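For reference, a minimal standalone sketch of the batch-expansion pattern this step applies to the image latents, mask, and masked image latents (the helper name and shapes below are illustrative, not part of the pipeline API):

import torch

def expand_to_final_batch(tensor: torch.Tensor, batch_size: int, num_images_per_prompt: int) -> torch.Tensor:
    # Inputs may have batch size 1 or `batch_size`; both are tiled up to
    # batch_size * num_images_per_prompt with `repeat`.
    if tensor.shape[0] not in (1, batch_size):
        raise ValueError(f"expected batch size 1 or {batch_size}, got {tensor.shape[0]}")
    final_batch_size = batch_size * num_images_per_prompt
    return tensor.repeat(final_batch_size // tensor.shape[0], 1, 1, 1)

# One encoded image expanded for 2 prompts x 2 images per prompt -> batch of 4.
image_latents = torch.randn(1, 4, 128, 128)
print(expand_to_final_batch(image_latents, batch_size=2, num_images_per_prompt=2).shape)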
@@ -698,7 +703,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
add_noise=add_noise,
)
self.set_block_state(state, block_state)
return components, state
@@ -755,11 +759,13 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
]
def check_inputs(self, batch_size, image_latents):
if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
raise ValueError(
f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
)
@staticmethod
def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
@@ -788,7 +794,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
block_state.image_latents = block_state.image_latents.repeat(
final_batch_size // block_state.image_latents.shape[0], 1, 1, 1
)
if block_state.latent_timestep is not None:
block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
@@ -935,7 +943,9 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("unet", UNet2DConditionModel),]
return [
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
@@ -976,7 +986,11 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
InputParam(
"dtype",
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
]
@property
@@ -1052,7 +1066,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype
@@ -1087,7 +1101,9 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
text_encoder_projection_dim=text_encoder_projection_dim,
)
block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(
device=device
)
self.set_block_state(state, block_state)
return components, state
@@ -1102,7 +1118,9 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("unet", UNet2DConditionModel),]
return [
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def inputs(self) -> List[Tuple[str, Any]]:
@@ -1196,7 +1214,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
original_size = block_state.original_size or (height, width)
target_size = block_state.target_size or (height, width)
block_state.add_time_ids = self._get_add_time_ids(
components,
original_size,
@@ -1218,7 +1235,9 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
block_state.negative_add_time_ids = block_state.add_time_ids
block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(
device=device
)
self.set_block_state(state, block_state)
return components, state
@@ -1229,7 +1248,9 @@ class StableDiffusionXLLCMStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("unet", UNet2DConditionModel),]
return [
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
@@ -1290,30 +1311,30 @@ class StableDiffusionXLLCMStep(PipelineBlock):
assert emb.shape == (w.shape[0], embedding_dim)
return emb
def check_input(self, unet, embedded_guidance_scale):
if embedded_guidance_scale is not None and unet.config.time_cond_proj_dim is None:
raise ValueError(f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None")
raise ValueError(
f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None"
)
if embedded_guidance_scale is None and unet.config.time_cond_proj_dim is not None:
raise ValueError(f"unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")
raise ValueError("unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
dtype = block_state.dtype if block_state.dtype is not None else components.unet.dtype
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
# Optionally get Guidance Scale Embedding for LCM
block_state.timestep_cond = None
guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
guidance_scale_tensor = (
torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
)
block_state.timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
).to(device=device, dtype=dtype)
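For reference, the sinusoidal guidance-scale embedding this step relies on can be sketched as a standalone function, along the lines of the get_guidance_scale_embedding helper used in existing diffusers pipelines (the embedding_dim default below is illustrative; the real value comes from unet.config.time_cond_proj_dim):

import torch

def get_guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Sinusoidal embedding of the (shifted) guidance scale, one row per sample.
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad to the requested width
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb

# e.g. embedded_guidance_scale = 8.0 for a final batch of 4 samples
guidance_scale_tensor = torch.tensor(8.0 - 1).repeat(4)
print(get_guidance_scale_embedding(guidance_scale_tensor, embedding_dim=256).shape)  # torch.Size([4, 256])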
@@ -1476,9 +1497,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
if isinstance(controlnet, MultiControlNetModel) and isinstance(
block_state.controlnet_conditioning_scale, float
):
block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(
controlnet.nets
)
block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
else:
block_state.conditioning_scale = block_state.controlnet_conditioning_scale
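For reference, a minimal sketch of how a single conditioning scale is broadcast when a MultiControlNetModel wraps several ControlNets (the helper name and the explicit length check are illustrative additions, not the pipeline's API):

from typing import List, Union

def broadcast_conditioning_scale(scale: Union[float, List[float]], num_controlnets: int) -> List[float]:
    # A scalar applies the same conditioning scale to every ControlNet;
    # a list is passed through as one scale per ControlNet.
    if isinstance(scale, float):
        return [scale] * num_controlnets
    if len(scale) != num_controlnets:
        raise ValueError(f"expected {num_controlnets} conditioning scales, got {len(scale)}")
    return scale

print(broadcast_conditioning_scale(0.8, 2))  # [0.8, 0.8]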

View File

@@ -130,9 +130,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = (
latents * latents_std / components.vae.config.scaling_factor + latents_mean
)
latents = latents * latents_std / components.vae.config.scaling_factor + latents_mean
else:
latents = latents / components.vae.config.scaling_factor
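For reference, a minimal sketch of the latent denormalization applied before VAE decoding (the numbers below are illustrative placeholders for values normally read from vae.config):

import torch

scaling_factor = 0.13025                  # e.g. vae.config.scaling_factor
latents_mean = torch.zeros(1, 4, 1, 1)    # e.g. vae.config.latents_mean, when defined
latents_std = torch.ones(1, 4, 1, 1)      # e.g. vae.config.latents_std, when defined

latents = torch.randn(2, 4, 128, 128)

# With per-channel statistics: undo both the scaling and the normalization.
decoded_input = latents * latents_std / scaling_factor + latents_mean
# Without statistics: only the scaling factor is undone.
decoded_input_fallback = latents / scaling_factor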

View File

@@ -15,6 +15,7 @@
from typing import List, Optional, Tuple
import torch
from PIL import Image
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
@@ -39,8 +40,6 @@ from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline
from PIL import Image
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -67,7 +66,6 @@ def get_clip_prompt_embeds(
clip_skip=None,
max_length=None,
):
text_inputs = tokenizer(
prompt,
padding="max_length",
@@ -79,9 +77,7 @@ def get_clip_prompt_embeds(
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -99,24 +95,20 @@ def get_clip_prompt_embeds(
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
return prompt_embeds, pooled_prompt_embeds
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
def encode_vae_image(
image: torch.Tensor,
vae: AutoencoderKL,
generator: torch.Generator,
dtype: torch.dtype,
device: torch.device
image: torch.Tensor, vae: AutoencoderKL, generator: torch.Generator, dtype: torch.dtype, device: torch.device
):
latents_mean = latents_std = None
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
image = image.to(device=device, dtype=dtype)
if vae.config.force_upcast:
@@ -131,8 +123,7 @@ def encode_vae_image(
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
@@ -200,7 +191,11 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="IP adapter image embeddings"),
OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative IP adapter image embeddings"),
OutputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative IP adapter image embeddings",
),
]
@staticmethod
@@ -229,7 +224,6 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
return image_embeds, uncond_image_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
@@ -245,7 +239,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
if len(block_state.ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(block_state.ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
@@ -333,20 +327,17 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
@staticmethod
def check_inputs(prompt, prompt_2, negative_prompt, negative_prompt_2):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_2 is not None and (
not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
):
if prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if negative_prompt_2 is not None and (
not isinstance(negative_prompt_2, str) and not isinstance(negative_prompt_2, list)
):
@@ -394,7 +385,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
"""
dtype = components.text_encoder_2.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
@@ -421,7 +411,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
adjust_lora_scale_text_encoder(text_encoder, lora_scale)
else:
scale_lora_layers(text_encoder, lora_scale)
# Define prompts
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -436,12 +426,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds(
prompt=prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
prompt=prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
clip_skip=clip_skip,
max_length=tokenizer.model_max_length
max_length=tokenizer.model_max_length,
)
prompt_embeds_list.append(prompt_embeds)
@@ -492,12 +482,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
max_length = prompt_embeds.shape[1]
negative_prompt_embeds, negative_pooled_prompt_embeds = get_clip_prompt_embeds(
prompt=negative_prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
prompt=negative_prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
clip_skip=None,
max_length=max_length
max_length=max_length,
)
negative_prompt_embeds_list.append(negative_prompt_embeds)
if negative_pooled_prompt_embeds.ndim == 2:
@@ -523,8 +513,10 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2)
self.check_inputs(
block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2
)
device = components._execution_device
@@ -608,11 +600,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
# Encode image into latents
block_state.image_latents = encode_vae_image(
image=image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
)
self.set_block_state(state, block_state)
@@ -681,14 +669,13 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
),
OutputParam(
"mask",
type_hint=torch.Tensor,
"mask",
type_hint=torch.Tensor,
description="The mask to apply on the latents for the inpainting generation.",
),
]
def check_inputs(self, image, mask_image, padding_mask_crop):
if padding_mask_crop is not None and not isinstance(image, Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
@@ -696,10 +683,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
if padding_mask_crop is not None and not isinstance(mask_image, Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
f"The mask image should be a PIL image when inpainting mask crop, but is of type {type(mask_image)}."
)
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
@@ -738,32 +724,24 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
resize_mode=resize_mode,
crops_coords=block_state.crops_coords,
)
masked_image = image * (mask_image < 0.5)
# Prepare image latent variables
block_state.image_latents = encode_vae_image(
image=image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
)
# Prepare masked image latent variables
block_state.masked_image_latents = encode_vae_image(
image=masked_image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
image=masked_image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
)
# resize mask to match the image latents
_, _, height_latents, width_latents = block_state.image_latents.shape
block_state.mask = torch.nn.functional.interpolate(
mask_image,
size=(height_latents, width_latents),
mask_image,
size=(height_latents, width_latents),
)
block_state.mask = block_state.mask.to(dtype=dtype, device=device)
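For reference, a minimal standalone sketch of the mask preparation this step performs, with illustrative tensors standing in for the processed inputs (an 8x latent downsampling is assumed, as with the SDXL VAE):

import torch
import torch.nn.functional as F

image = torch.rand(1, 3, 1024, 1024) * 2 - 1                # preprocessed image in [-1, 1]
mask_image = (torch.rand(1, 1, 1024, 1024) > 0.5).float()   # binary inpainting mask

# Zero out the region to be inpainted before it is VAE-encoded.
masked_image = image * (mask_image < 0.5)

# Resize the mask to the latent resolution so it lines up with image_latents.
height_latents, width_latents = 1024 // 8, 1024 // 8
mask = F.interpolate(mask_image, size=(height_latents, width_latents))
print(masked_image.shape, mask.shape)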

View File

@@ -23,10 +23,10 @@ from .before_denoise import (
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLInputStep,
StableDiffusionXLLCMStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLLCMStep,
)
from .decoders import (
StableDiffusionXLDecodeStep,
@@ -372,7 +372,6 @@ IP_ADAPTER_BLOCKS = InsertableDict(
)
LCM_BLOCKS = InsertableDict(
[
("lcm", StableDiffusionXLAutoLCMStep),
]

View File

@@ -89,7 +89,7 @@ class StableDiffusionXLModularPipeline(
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents
@property
def requires_unconditional_embeds(self):
# by default, always prepare unconditional embeddings
@@ -101,7 +101,7 @@ class StableDiffusionXLModularPipeline(
elif hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider.num_conditions > 1
elif not hasattr(self, "guider") or self.guider is None:
requires_unconditional_embeds = False