Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
@@ -607,8 +607,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
             ),
         ]
 
+    @staticmethod
     def prepare_latents(
-        self,
         image_latents,
         scheduler,
         dtype,
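The hunk above appears to turn `prepare_latents` into a `@staticmethod` while dropping `self` from its signature. A minimal sketch (hypothetical `Step` class, not the diffusers block) of why the two edits belong together:

class Step:
    @staticmethod
    def prepare_latents(image_latents, scheduler, dtype):
        # Purely functional: operates only on its arguments, no instance state.
        return list(image_latents)


step = Step()
# Both call styles work once the method is static; without the decorator,
# `step.prepare_latents(...)` would inject `step` as the first argument
# and shift every parameter by one.
print(Step.prepare_latents([1, 2, 3], scheduler=None, dtype=None))
print(step.prepare_latents([1, 2, 3], scheduler=None, dtype=None))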
@@ -760,6 +760,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
         if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
             raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
 
+    @staticmethod
     def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
         if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
             raise ValueError(
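The guard rejects `image_latents` whose batch dimension is neither 1 nor the requested batch size. A rough standalone sketch of that contract, assuming the usual convention that a batch of 1 is repeated to the target size (hypothetical helper, not the pipeline block):

import torch


def expand_image_latents(image_latents: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Accept either a single latent (broadcast to the batch) or an exact match.
    if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
        raise ValueError(
            f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
        )
    if image_latents.shape[0] == 1:
        image_latents = image_latents.repeat(batch_size, 1, 1, 1)
    return image_latents


latents = torch.randn(1, 4, 128, 128)
print(expand_image_latents(latents, batch_size=2).shape)  # torch.Size([2, 4, 128, 128])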
@@ -975,6 +976,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
+            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
         ]
 
     @property
@@ -992,7 +994,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
                 kwargs_type="guider_input_fields",
                 description="The negative time ids to condition the denoising process",
             ),
-            OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
         ]
 
     @staticmethod
@@ -1136,6 +1137,11 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
+            InputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs. Can be generated in input step.",
+            ),
         ]
 
     @property
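Both conditioning steps now declare `dtype` as an explicit input spec. The diff shows `InputParam` taking a name plus `type_hint` and `description`; a toy re-implementation of that pattern (hypothetical `InputParam` dataclass, not the diffusers one) to illustrate how such declarative specs are typically consumed:

from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class InputParam:
    # Minimal stand-in for a declarative input spec: a name, an expected
    # type hint and a human-readable description.
    name: str
    type_hint: Any = None
    description: str = ""
    default: Optional[Any] = None


inputs = [
    InputParam("batch_size", type_hint=int, description="Number of prompts."),
    InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs."),
]

# A block can look up what it needs from a shared state dict by spec name.
state = {"batch_size": 2, "dtype": torch.float16}
resolved = {p.name: state.get(p.name, p.default) for p in inputs}
print(resolved)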
@@ -1187,8 +1193,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
         _, _, height_latents, width_latents = block_state.latents.shape
         height = height_latents * components.vae_scale_factor
         width = width_latents * components.vae_scale_factor
-        original_size = block_state.original_size or (block_state.height, block_state.width)
-        target_size = block_state.target_size or (block_state.height, block_state.width)
+        original_size = block_state.original_size or (height, width)
+        target_size = block_state.target_size or (height, width)
 
 
         block_state.add_time_ids = self._get_add_time_ids(
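The corrected lines derive `original_size`/`target_size` from the actual latent resolution rather than the possibly-unset `block_state.height`/`width`. A small sketch of the arithmetic, assuming the SDXL VAE scale factor of 8:

import torch

vae_scale_factor = 8  # assumption: the VAE downsamples by 8x, as in SDXL

latents = torch.randn(1, 4, 128, 128)
_, _, height_latents, width_latents = latents.shape

height = height_latents * vae_scale_factor
width = width_latents * vae_scale_factor

original_size = (height, width)
target_size = (height, width)
print(original_size, target_size)  # (1024, 1024) (1024, 1024)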
@@ -139,7 +139,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
             block_state.images = components.vae.decode(latents, return_dict=False)[0]
 
             # cast back to fp16 if needed
-            if block_state.needs_upcasting:
+            if needs_upcasting:
                 components.vae.to(dtype=torch.float16)
         else:
             block_state.images = block_state.latents
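The decode step temporarily upcasts the VAE for decoding and restores fp16 afterwards; this hunk only changes where the `needs_upcasting` flag lives. A hedged sketch of the surrounding upcast/decode/downcast dance, with a stand-in module instead of the real VAE:

import torch
import torch.nn as nn

# Stand-in for the VAE decoder; the real block calls components.vae.decode().
vae = nn.Conv2d(4, 3, kernel_size=1).to(dtype=torch.float16)
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)

# Assumption: decide to upcast when the model sits in fp16 (SDXL pipelines
# typically also consult the VAE's force_upcast config flag).
needs_upcasting = vae.weight.dtype == torch.float16
if needs_upcasting:
    vae.to(dtype=torch.float32)
    latents = latents.to(dtype=torch.float32)

images = vae(latents)

# cast back to fp16 if needed
if needs_upcasting:
    vae.to(dtype=torch.float16)

print(images.dtype)  # torch.float32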
@@ -39,6 +39,8 @@ from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import StableDiffusionXLModularPipeline
 
+from PIL import Image
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -504,11 +506,11 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
             negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0)
 
-        prompt_embeds = prompt_embeds.to(dtype, device=device)
-        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
         if requires_unconditional_embeds:
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype, device=device)
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype=dtype, device=device)
 
         for text_encoder in text_encoders:
             if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
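The embedding moves are rewritten to pass both arguments as keywords. A short, plain-torch illustration of why the keyword form is the unambiguous one (no pipeline code involved):

import torch

t = torch.randn(2, 4)
dtype, device = torch.float16, torch.device("cpu")

# Unambiguous: dtype and device both named, converted in a single call.
out = t.to(dtype=dtype, device=device)
print(out.dtype, out.device)  # torch.float16 cpu

# Mixing a positional dtype with a device keyword does not match any
# Tensor.to() overload cleanly and can raise a TypeError.
try:
    t.to(dtype, device=device)
except TypeError as exc:
    print("rejected:", exc)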
@@ -687,12 +689,12 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
 
     def check_inputs(self, image, mask_image, padding_mask_crop):
 
-        if padding_mask_crop is not None and not isinstance(image, PIL.Image.Image):
+        if padding_mask_crop is not None and not isinstance(image, Image.Image):
             raise ValueError(
                 f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
             )
 
-        if padding_mask_crop is not None and not isinstance(mask_image, PIL.Image.Image):
+        if padding_mask_crop is not None and not isinstance(mask_image, Image.Image):
             raise ValueError(
                 f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                 f" {type(mask_image)}."
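With `from PIL import Image` added at the top of the file, the checks compare against `Image.Image` directly. A minimal reproduction of the validation (standalone function, not the block method):

from PIL import Image


def check_inputs(image, padding_mask_crop=None):
    # Cropping around the mask only works on PIL images, so reject tensors/arrays here.
    if padding_mask_crop is not None and not isinstance(image, Image.Image):
        raise ValueError(
            f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
        )


check_inputs(Image.new("RGB", (64, 64)), padding_mask_crop=32)  # passes
try:
    check_inputs("not-an-image", padding_mask_crop=32)
except ValueError as exc:
    print(exc)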
@@ -707,10 +709,8 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
         dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
         device = components._execution_device
 
-        if block_state.height is None:
-            height = components.default_height
-        if block_state.width is None:
-            width = components.default_width
+        height = block_state.height if block_state.height is not None else components.default_height
+        width = block_state.width if block_state.width is not None else components.default_width
 
         if block_state.padding_mask_crop is not None:
             block_state.crops_coords = components.mask_processor.get_crop_region(
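The replaced lines fix a latent bug: the old code only assigned `height`/`width` when the corresponding `block_state` value was None, leaving the locals undefined otherwise. A tiny sketch of the corrected pattern, with hypothetical defaults standing in for `components.default_height`/`default_width`:

DEFAULT_HEIGHT, DEFAULT_WIDTH = 1024, 1024  # hypothetical defaults


def resolve_size(requested_height=None, requested_width=None):
    # Always bind both locals, falling back to the defaults only when unset.
    height = requested_height if requested_height is not None else DEFAULT_HEIGHT
    width = requested_width if requested_width is not None else DEFAULT_WIDTH
    return height, width


print(resolve_size())          # (1024, 1024)
print(resolve_size(768, 512))  # (768, 512)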
@@ -725,21 +725,21 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             block_state.image,
             height=height,
             width=width,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
             resize_mode=resize_mode,
         )
 
         image = image.to(dtype=torch.float32)
 
-        mask = components.mask_processor.preprocess(
+        mask_image = components.mask_processor.preprocess(
             block_state.mask_image,
             height=height,
             width=width,
             resize_mode=resize_mode,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
         )
 
-        masked_image = image * (block_state.mask_latents < 0.5)
+        masked_image = image * (mask_image < 0.5)
 
         # Prepare image latent variables
         block_state.image_latents = encode_vae_image(
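After preprocessing, the masked image is the input image with the to-be-inpainted region zeroed out; the fix simply uses the freshly preprocessed `mask_image` tensor instead of a not-yet-existing `block_state` attribute. A standalone sketch of that masking step:

import torch

# image in [-1, 1], mask in [0, 1] where values >= 0.5 mark the region to repaint.
image = torch.rand(1, 3, 64, 64) * 2 - 1
mask_image = (torch.rand(1, 1, 64, 64) > 0.5).float()

# Keep only the pixels that are NOT being inpainted.
masked_image = image * (mask_image < 0.5)
print(masked_image.shape)  # torch.Size([1, 3, 64, 64])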
@@ -762,7 +762,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
         # resize mask to match the image latents
         _, _, height_latents, width_latents = block_state.image_latents.shape
         block_state.mask = torch.nn.functional.interpolate(
-            mask,
+            mask_image,
             size=(height_latents, width_latents),
         )
         block_state.mask = block_state.mask.to(dtype=dtype, device=device)
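The resize keeps the mask aligned with the VAE latents, which are smaller than the pixel-space mask; only the tensor name changes in the hunk. A short standalone version of the resize:

import torch
import torch.nn.functional as F

mask_image = (torch.rand(1, 1, 1024, 1024) > 0.5).float()
image_latents = torch.randn(1, 4, 128, 128)

# Resize the pixel-space mask to the latent grid so it can gate the latents directly.
_, _, height_latents, width_latents = image_latents.shape
mask = F.interpolate(mask_image, size=(height_latents, width_latents))
print(mask.shape)  # torch.Size([1, 1, 128, 128])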
@@ -95,11 +95,15 @@ class StableDiffusionXLModularPipeline(
         # by default, always prepare unconditional embeddings
         requires_unconditional_embeds = True
 
-        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is None:
+        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is not None:
             # LCM
             requires_unconditional_embeds = False
 
         elif hasattr(self, "guider") and self.guider is not None:
             requires_unconditional_embeds = self.guider.num_conditions > 1
 
+        elif not hasattr(self, "guider") or self.guider is None:
+            requires_unconditional_embeds = False
+
         return requires_unconditional_embeds
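The corrected condition skips unconditional embeddings when the UNet exposes `time_cond_proj_dim` (the LCM case), and the new branch also skips them when no guider is attached. A simplified, self-contained version of that decision logic using plain stand-in objects rather than the pipeline:

from types import SimpleNamespace


def requires_unconditional_embeds(unet=None, guider=None) -> bool:
    # By default, always prepare unconditional (negative) embeddings.
    required = True

    if unet is not None and unet.config.time_cond_proj_dim is not None:
        # LCM-style UNets take the guidance scale as a direct conditioning input,
        # so no separate negative branch is needed.
        required = False
    elif guider is not None:
        # Guidance needs extra conditions (e.g. CFG runs two).
        required = guider.num_conditions > 1
    else:
        required = False

    return required


lcm_unet = SimpleNamespace(config=SimpleNamespace(time_cond_proj_dim=256))
sdxl_unet = SimpleNamespace(config=SimpleNamespace(time_cond_proj_dim=None))
cfg_guider = SimpleNamespace(num_conditions=2)

print(requires_unconditional_embeds(unet=lcm_unet, guider=cfg_guider))   # False
print(requires_unconditional_embeds(unet=sdxl_unet, guider=cfg_guider))  # True
print(requires_unconditional_embeds(unet=sdxl_unet, guider=None))        # False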