Mirror of https://github.com/huggingface/diffusers.git
yiyixuxu
2025-08-08 19:01:22 +02:00
parent ed881a15fd
commit 4b367e8edd
4 changed files with 31 additions and 21 deletions

View File

@@ -607,8 +607,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
             ),
         ]
 
+    @staticmethod
     def prepare_latents(
-        self,
         image_latents,
         scheduler,
         dtype,
@@ -760,6 +760,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
         if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
             raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
 
+    @staticmethod
     def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
         if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
             raise ValueError(
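A minimal sketch of the static prepare_latents pattern these hunks move to (the class name and error message here are illustrative, and the scheduler/timestep noise-addition step is omitted):

class _LatentsStepSketch:
    @staticmethod
    def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
        # a @staticmethod takes no `self`; leaving `self` in the signature would
        # silently consume the first positional argument (the image latents)
        if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but "
                f"image_latents has batch size {image_latents.shape[0]}."
            )
        # noise addition via `scheduler` at `timestep` omitted in this sketch
        return image_latents.to(device=device, dtype=dtype)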
@@ -975,6 +976,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
+            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
         ]
 
     @property
@@ -992,7 +994,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
kwargs_type="guider_input_fields",
description="The negative time ids to condition the denoising process",
),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
]
@staticmethod
@@ -1136,6 +1137,11 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
+            InputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs. Can be generated in input step.",
+            ),
         ]
 
     @property
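For reference, a hedged sketch of how the new optional "dtype" input is declared (the absolute import path below is an assumption; the file itself uses a relative import):

import torch
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam  # path assumed

extra_inputs = [
    InputParam(
        "dtype",
        type_hint=torch.dtype,
        description="The dtype of the model inputs. Can be generated in input step.",
    ),
]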
@@ -1187,8 +1193,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
         _, _, height_latents, width_latents = block_state.latents.shape
         height = height_latents * components.vae_scale_factor
         width = width_latents * components.vae_scale_factor
 
-        original_size = block_state.original_size or (block_state.height, block_state.width)
-        target_size = block_state.target_size or (block_state.height, block_state.width)
+        original_size = block_state.original_size or (height, width)
+        target_size = block_state.target_size or (height, width)
 
         block_state.add_time_ids = self._get_add_time_ids(
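A small sketch of the sizing fix above: original_size and target_size now fall back to dimensions derived from the latents themselves, which always exist, rather than block_state.height/width, which may be None (all values below are examples):

import torch

latents = torch.randn(1, 4, 128, 128)                       # example latents
vae_scale_factor = 8                                         # SDXL default
_, _, height_latents, width_latents = latents.shape
height = height_latents * vae_scale_factor                   # 1024
width = width_latents * vae_scale_factor                     # 1024

requested_original_size = None                               # stand-in for block_state.original_size
requested_target_size = None                                 # stand-in for block_state.target_size
original_size = requested_original_size or (height, width)   # -> (1024, 1024)
target_size = requested_target_size or (height, width)       # -> (1024, 1024)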

View File

@@ -139,7 +139,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
             block_state.images = components.vae.decode(latents, return_dict=False)[0]
 
             # cast back to fp16 if needed
-            if block_state.needs_upcasting:
+            if needs_upcasting:
                 components.vae.to(dtype=torch.float16)
         else:
             block_state.images = block_state.latents
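A sketch of the upcast/downcast flow the corrected line participates in, assuming an AutoencoderKL-style VAE with a force_upcast config flag; needs_upcasting is a plain local computed before decoding, which is why the restore branch reads it directly:

import torch
from diffusers import AutoencoderKL

def decode_latents_sketch(vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
    needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    if needs_upcasting:
        vae.to(dtype=torch.float32)        # decode in fp32 to avoid fp16 overflow
        latents = latents.to(torch.float32)
    images = vae.decode(latents, return_dict=False)[0]
    if needs_upcasting:
        vae.to(dtype=torch.float16)        # cast back to fp16 if needed
    return images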

View File

@@ -39,6 +39,8 @@ from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import StableDiffusionXLModularPipeline
+from PIL import Image
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -504,11 +506,11 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
             negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0)
 
-        prompt_embeds = prompt_embeds.to(dtype, device=device)
-        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
         if requires_unconditional_embeds:
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype, device=device)
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype=dtype, device=device)
 
         for text_encoder in text_encoders:
             if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
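The keyword-only form is the unambiguous way to move a tensor's dtype and device in one call; as far as the Tensor.to overloads go, mixing a positional dtype with a keyword device does not resolve to a single overload and raises a TypeError:

import torch

t = torch.randn(2, 4)
t = t.to(dtype=torch.float16, device="cpu")   # works: explicit keywords
# t = t.to(torch.float16, device="cpu")       # raises TypeError (no matching Tensor.to overload)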
@@ -687,12 +689,12 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
     def check_inputs(self, image, mask_image, padding_mask_crop):
-        if padding_mask_crop is not None and not isinstance(image, PIL.Image.Image):
+        if padding_mask_crop is not None and not isinstance(image, Image.Image):
             raise ValueError(
                 f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
             )
 
-        if padding_mask_crop is not None and not isinstance(mask_image, PIL.Image.Image):
+        if padding_mask_crop is not None and not isinstance(mask_image, Image.Image):
             raise ValueError(
                 f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                 f" {type(mask_image)}."
@@ -707,10 +709,8 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
         dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
         device = components._execution_device
 
-        if block_state.height is None:
-            height = components.default_height
-        if block_state.width is None:
-            width = components.default_width
+        height = block_state.height if block_state.height is not None else components.default_height
+        width = block_state.width if block_state.width is not None else components.default_width
 
         if block_state.padding_mask_crop is not None:
             block_state.crops_coords = components.mask_processor.get_crop_region(
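A sketch of the control-flow bug fixed above: the old form bound each local only when the requested value was None, leaving it unbound whenever a caller did pass height or width; the conditional-expression form always binds both locals (values below are examples):

default_height, default_width = 1024, 1024     # stand-ins for components.default_height / default_width
requested_height, requested_width = None, 768  # stand-ins for block_state.height / block_state.width

height = requested_height if requested_height is not None else default_height   # -> 1024
width = requested_width if requested_width is not None else default_width       # -> 768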
@@ -725,21 +725,21 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             block_state.image,
             height=height,
             width=width,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
             resize_mode=resize_mode,
         )
         image = image.to(dtype=torch.float32)
 
-        mask = components.mask_processor.preprocess(
+        mask_image = components.mask_processor.preprocess(
             block_state.mask_image,
             height=height,
             width=width,
             resize_mode=resize_mode,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
         )
 
-        masked_image = image * (block_state.mask_latents < 0.5)
+        masked_image = image * (mask_image < 0.5)
 
         # Prepare image latent variables
         block_state.image_latents = encode_vae_image(
@@ -762,7 +762,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
         # resize mask to match the image latents
         _, _, height_latents, width_latents = block_state.image_latents.shape
         block_state.mask = torch.nn.functional.interpolate(
-            mask,
+            mask_image,
             size=(height_latents, width_latents),
         )
         block_state.mask = block_state.mask.to(dtype=dtype, device=device)
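An end-to-end sketch of the mask path after the rename (shapes are examples only): the preprocessed mask stays in pixel space while masked_image is built, then is interpolated down to the latent grid:

import torch
import torch.nn.functional as F

mask_image = torch.rand(1, 1, 1024, 1024)      # preprocessed mask in pixel space
image = torch.randn(1, 3, 1024, 1024)          # preprocessed image in pixel space
image_latents = torch.randn(1, 4, 128, 128)    # VAE-encoded image latents

masked_image = image * (mask_image < 0.5)      # keep only the pixels outside the inpaint region
_, _, height_latents, width_latents = image_latents.shape
mask = F.interpolate(mask_image, size=(height_latents, width_latents))  # resize mask to the latent grid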

View File

@@ -95,11 +95,15 @@ class StableDiffusionXLModularPipeline(
         # by default, always prepare unconditional embeddings
         requires_unconditional_embeds = True
-        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is None:
+        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is not None:
             # LCM
             requires_unconditional_embeds = False
+        elif hasattr(self, "guider") and self.guider is not None:
+            requires_unconditional_embeds = self.guider.num_conditions > 1
+        elif not hasattr(self, "guider") or self.guider is None:
+            requires_unconditional_embeds = False
 
         return requires_unconditional_embeds
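A standalone sketch of the corrected decision (function and argument names are illustrative; the real logic lives on StableDiffusionXLModularPipeline): an LCM-style UNet with time_cond_proj_dim set never needs negative embeddings, otherwise the guider decides, and with no guider nothing is prepared:

def requires_unconditional_embeds_sketch(unet=None, guider=None):
    result = True                           # by default, always prepare unconditional embeddings
    if unet is not None and unet.config.time_cond_proj_dim is not None:
        result = False                      # LCM: guidance is distilled into the model
    elif guider is not None:
        result = guider.num_conditions > 1  # e.g. classifier-free guidance needs the negative branch
    else:
        result = False                      # no guider configured
    return result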