Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00

Commit: style
@@ -15,13 +15,11 @@
import inspect
from typing import Any, List, Optional, Tuple, Union

import PIL
import torch

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
from ...models import ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging

@@ -591,7 +589,11 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
type_hint=torch.Tensor,
description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
InputParam(
"dtype",
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
]

@property

@@ -618,7 +620,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
is_strength_max=True,
add_noise=True,
):

batch_size = image_latents.shape[0]

if isinstance(generator, list) and len(generator) != batch_size:
@@ -640,46 +641,50 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):

return latents, noise

def check_inputs(self, batch_size, image_latents, mask, masked_image_latents):

if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
raise ValueError(
f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
)

if not (mask.shape[0] == 1 or mask.shape[0] == batch_size):
raise ValueError(f"mask should have batch size 1 or {batch_size}, but got {mask.shape[0]}")

if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
raise ValueError(f"masked_image_latents should have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")

raise ValueError(
f"masked_image_latents should have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}"
)

@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

self.check_inputs(
batch_size=block_state.batch_size,
image_latents=block_state.image_latents,
mask=block_state.mask,
masked_image_latents=block_state.masked_image_latents,
)
image_latents=block_state.image_latents,
mask=block_state.mask,
masked_image_latents=block_state.masked_image_latents,
)

dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
device = components._execution_device

final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
block_state.image_latents = block_state.image_latents.repeat(
final_batch_size // block_state.image_latents.shape[0], 1, 1, 1
)

# 7. Prepare mask latent variables
block_state.mask = block_state.mask.to(device=device, dtype=dtype)
block_state.mask = block_state.mask.repeat(final_batch_size//block_state.mask.shape[0], 1, 1, 1)
block_state.mask = block_state.mask.repeat(final_batch_size // block_state.mask.shape[0], 1, 1, 1)

block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size//block_state.masked_image_latents.shape[0], 1, 1, 1)

block_state.masked_image_latents = block_state.masked_image_latents.repeat(
final_batch_size // block_state.masked_image_latents.shape[0], 1, 1, 1
)

if block_state.latent_timestep is not None:
block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype)
@@ -698,7 +703,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
add_noise=add_noise,
)

self.set_block_state(state, block_state)

return components, state
@@ -755,11 +759,13 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
]

def check_inputs(self, batch_size, image_latents):
if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")

raise ValueError(
f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
)

@staticmethod
def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
@@ -788,7 +794,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
block_state.image_latents = block_state.image_latents.repeat(
final_batch_size // block_state.image_latents.shape[0], 1, 1, 1
)

if block_state.latent_timestep is not None:
block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)

@@ -935,7 +943,9 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):

@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("unet", UNet2DConditionModel),]
return [
ComponentSpec("unet", UNet2DConditionModel),
]

@property
def description(self) -> str:

@@ -976,7 +986,11 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
InputParam(
"dtype",
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
]

@property

@@ -1052,7 +1066,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

device = components._execution_device
dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype
@@ -1087,7 +1101,9 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
text_encoder_projection_dim=text_encoder_projection_dim,
)
block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(
device=device
)

self.set_block_state(state, block_state)
return components, state

@@ -1102,7 +1118,9 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):

@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("unet", UNet2DConditionModel),]
return [
ComponentSpec("unet", UNet2DConditionModel),
]

@property
def inputs(self) -> List[Tuple[str, Any]]:

@@ -1196,7 +1214,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
original_size = block_state.original_size or (height, width)
target_size = block_state.target_size or (height, width)

block_state.add_time_ids = self._get_add_time_ids(
components,
original_size,

@@ -1218,7 +1235,9 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
block_state.negative_add_time_ids = block_state.add_time_ids

block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(
device=device
)

self.set_block_state(state, block_state)
return components, state

@@ -1229,7 +1248,9 @@ class StableDiffusionXLLCMStep(PipelineBlock):

@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("unet", UNet2DConditionModel),]
return [
ComponentSpec("unet", UNet2DConditionModel),
]

@property
def description(self) -> str:
@@ -1290,30 +1311,30 @@ class StableDiffusionXLLCMStep(PipelineBlock):
assert emb.shape == (w.shape[0], embedding_dim)
return emb

def check_input(self, unet, embedded_guidance_scale):

if embedded_guidance_scale is not None and unet.config.time_cond_proj_dim is None:
raise ValueError(f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None")
raise ValueError(
f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None"
)

if embedded_guidance_scale is None and unet.config.time_cond_proj_dim is not None:
raise ValueError(f"unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")
raise ValueError("unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")

@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

device = components._execution_device
dtype = block_state.dtype if block_state.dtype is not None else components.unet.dtype

final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

# Optionally get Guidance Scale Embedding for LCM
block_state.timestep_cond = None

guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)

guidance_scale_tensor = (
torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
)
block_state.timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
).to(device=device, dtype=dtype)
@@ -1476,9 +1497,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
if isinstance(controlnet, MultiControlNetModel) and isinstance(
block_state.controlnet_conditioning_scale, float
):
block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(
controlnet.nets
)
block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
else:
block_state.conditioning_scale = block_state.controlnet_conditioning_scale

@@ -130,9 +130,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = (
latents * latents_std / components.vae.config.scaling_factor + latents_mean
)
latents = latents * latents_std / components.vae.config.scaling_factor + latents_mean
else:
latents = latents / components.vae.config.scaling_factor
@@ -15,6 +15,7 @@
from typing import List, Optional, Tuple

import torch
from PIL import Image
from transformers import (
CLIPImageProcessor,
CLIPTextModel,

@@ -39,8 +40,6 @@ from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline

from PIL import Image

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -67,7 +66,6 @@ def get_clip_prompt_embeds(
clip_skip=None,
max_length=None,
):

text_inputs = tokenizer(
prompt,
padding="max_length",

@@ -79,9 +77,7 @@ def get_clip_prompt_embeds(
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -99,24 +95,20 @@ def get_clip_prompt_embeds(
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

return prompt_embeds, pooled_prompt_embeds

# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
def encode_vae_image(
image: torch.Tensor,
vae: AutoencoderKL,
generator: torch.Generator,
dtype: torch.dtype,
device: torch.device
image: torch.Tensor, vae: AutoencoderKL, generator: torch.Generator, dtype: torch.dtype, device: torch.device
):
latents_mean = latents_std = None
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)

image = image.to(device=device, dtype=dtype)

if vae.config.force_upcast:

@@ -131,8 +123,7 @@ def encode_vae_image(

if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
@@ -200,7 +191,11 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="IP adapter image embeddings"),
OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative IP adapter image embeddings"),
OutputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative IP adapter image embeddings",
),
]

@staticmethod

@@ -229,7 +224,6 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):

return image_embeds, uncond_image_embeds

@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

@@ -245,7 +239,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):

if len(block_state.ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(block_state.ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

for single_ip_adapter_image, image_proj_layer in zip(
@@ -333,20 +327,17 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):

@staticmethod
def check_inputs(prompt, prompt_2, negative_prompt, negative_prompt_2):

if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

if prompt_2 is not None and (
not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
):

if prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

if negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

if negative_prompt_2 is not None and (
not isinstance(negative_prompt_2, str) and not isinstance(negative_prompt_2, list)
):

@@ -394,7 +385,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
"""
dtype = components.text_encoder_2.dtype

prompt = [prompt] if isinstance(prompt, str) else prompt

batch_size = len(prompt)

@@ -421,7 +411,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
adjust_lora_scale_text_encoder(text_encoder, lora_scale)
else:
scale_lora_layers(text_encoder, lora_scale)

# Define prompts
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -436,12 +426,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
prompt = components.maybe_convert_prompt(prompt, tokenizer)

prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds(
prompt=prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
prompt=prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
clip_skip=clip_skip,
max_length=tokenizer.model_max_length
max_length=tokenizer.model_max_length,
)

prompt_embeds_list.append(prompt_embeds)

@@ -492,12 +482,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):

max_length = prompt_embeds.shape[1]
negative_prompt_embeds, negative_pooled_prompt_embeds = get_clip_prompt_embeds(
prompt=negative_prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
prompt=negative_prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
clip_skip=None,
max_length=max_length
max_length=max_length,
)
negative_prompt_embeds_list.append(negative_prompt_embeds)
if negative_pooled_prompt_embeds.ndim == 2:

@@ -523,8 +513,10 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)

self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2)

self.check_inputs(
block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2
)

device = components._execution_device
@@ -608,11 +600,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):

# Encode image into latents
block_state.image_latents = encode_vae_image(
image=image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
)

self.set_block_state(state, block_state)

@@ -681,14 +669,13 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
),
OutputParam(
"mask",
type_hint=torch.Tensor,
"mask",
type_hint=torch.Tensor,
description="The mask to apply on the latents for the inpainting generation.",
),
]

def check_inputs(self, image, mask_image, padding_mask_crop):

if padding_mask_crop is not None and not isinstance(image, Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."

@@ -696,10 +683,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):

if padding_mask_crop is not None and not isinstance(mask_image, Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
f"The mask image should be a PIL image when inpainting mask crop, but is of type {type(mask_image)}."
)

@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
@@ -738,32 +724,24 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
resize_mode=resize_mode,
crops_coords=block_state.crops_coords,
)

masked_image = image * (mask_image < 0.5)

# Prepare image latent variables
block_state.image_latents = encode_vae_image(
image=image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
)

# Prepare masked image latent variables
block_state.masked_image_latents = encode_vae_image(
image=masked_image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
image=masked_image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
)

# resize mask to match the image latents
_, _, height_latents, width_latents = block_state.image_latents.shape
block_state.mask = torch.nn.functional.interpolate(
mask_image,
size=(height_latents, width_latents),
mask_image,
size=(height_latents, width_latents),
)
block_state.mask = block_state.mask.to(dtype=dtype, device=device)
@@ -23,10 +23,10 @@ from .before_denoise import (
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLInputStep,
StableDiffusionXLLCMStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLLCMStep,
)
from .decoders import (
StableDiffusionXLDecodeStep,

@@ -372,7 +372,6 @@ IP_ADAPTER_BLOCKS = InsertableDict(
)

LCM_BLOCKS = InsertableDict(

[
("lcm", StableDiffusionXLAutoLCMStep),
]

@@ -89,7 +89,7 @@ class StableDiffusionXLModularPipeline(
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents

@property
def requires_unconditional_embeds(self):
# by default, always prepare unconditional embeddings

@@ -101,7 +101,7 @@ class StableDiffusionXLModularPipeline(

elif hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider.num_conditions > 1

elif not hasattr(self, "guider") or self.guider is None:
requires_unconditional_embeds = False