diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index 3e5301987e..e418e125ae 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -15,13 +15,11 @@
 import inspect
 from typing import Any, List, Optional, Tuple, Union
 
-import PIL
 import torch
 
 from ...configuration_utils import FrozenDict
-from ...guiders import ClassifierFreeGuidance
 from ...image_processor import VaeImageProcessor
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
+from ...models import ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
 from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from ...schedulers import EulerDiscreteScheduler
 from ...utils import logging
@@ -591,7 +589,11 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
                 type_hint=torch.Tensor,
                 description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
             ),
-            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
+            InputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs, can be generated in input step.",
+            ),
         ]
 
     @property
@@ -618,7 +620,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
         is_strength_max=True,
         add_noise=True,
     ):
-
         batch_size = image_latents.shape[0]
 
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -640,46 +641,50 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
 
         return latents, noise
 
-
-
     def check_inputs(self, batch_size, image_latents, mask, masked_image_latents):
-
         if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
-            raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
+            raise ValueError(
+                f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
+            )
 
         if not (mask.shape[0] == 1 or mask.shape[0] == batch_size):
             raise ValueError(f"mask should have have batch size 1 or {batch_size}, but got {mask.shape[0]}")
-
+
         if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
-            raise ValueError(f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")
-
-
+            raise ValueError(
+                f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}"
+            )
+
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
+
         self.check_inputs(
             batch_size=block_state.batch_size,
-            image_latents=block_state.image_latents, 
-            mask=block_state.mask, 
-            masked_image_latents=block_state.masked_image_latents, 
-            )
+            image_latents=block_state.image_latents,
+            mask=block_state.mask,
+            masked_image_latents=block_state.masked_image_latents,
+        )
 
         dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
         device = components._execution_device
-
-        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt 
-
+
+        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
+
         block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
-        block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
+        block_state.image_latents = block_state.image_latents.repeat(
+            final_batch_size // block_state.image_latents.shape[0], 1, 1, 1
+        )
 
         # 7. Prepare mask latent variables
         block_state.mask = block_state.mask.to(device=device, dtype=dtype)
-        block_state.mask = block_state.mask.repeat(final_batch_size//block_state.mask.shape[0], 1, 1, 1)
+        block_state.mask = block_state.mask.repeat(final_batch_size // block_state.mask.shape[0], 1, 1, 1)
 
         block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
-        block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size//block_state.masked_image_latents.shape[0], 1, 1, 1)
-
+        block_state.masked_image_latents = block_state.masked_image_latents.repeat(
+            final_batch_size // block_state.masked_image_latents.shape[0], 1, 1, 1
+        )
+
         if block_state.latent_timestep is not None:
             block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
             block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype)
@@ -698,7 +703,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
             add_noise=add_noise,
         )
 
-
         self.set_block_state(state, block_state)
 
         return components, state
@@ -755,11 +759,13 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
                 "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
             )
         ]
-
+
    def check_inputs(self, batch_size, image_latents):
        if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
-            raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
-
+            raise ValueError(
+                f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
+            )
+
     @staticmethod
     def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
         if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
@@ -788,7 +794,9 @@
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
 
         block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
-        block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
+        block_state.image_latents = block_state.image_latents.repeat(
+            final_batch_size // block_state.image_latents.shape[0], 1, 1, 1
+        )
 
         if block_state.latent_timestep is not None:
             block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
@@ -935,7 +943,9 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [ComponentSpec("unet", UNet2DConditionModel),]
+        return [
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
 
     @property
     def description(self) -> str:
@@ -976,7 +986,11 @@
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
-            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
+            InputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs, can be generated in input step.",
+            ),
         ]
 
     @property
@@ -1052,7 +1066,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
+
         device = components._execution_device
         dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype
 
@@ -1087,7 +1101,9 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
             text_encoder_projection_dim=text_encoder_projection_dim,
         )
         block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
-        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
+        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(
+            device=device
+        )
 
         self.set_block_state(state, block_state)
         return components, state
@@ -1102,7 +1118,9 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [ComponentSpec("unet", UNet2DConditionModel),]
+        return [
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
 
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
@@ -1196,7 +1214,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
         original_size = block_state.original_size or (height, width)
         target_size = block_state.target_size or (height, width)
 
-
         block_state.add_time_ids = self._get_add_time_ids(
             components,
             original_size,
@@ -1218,7 +1235,9 @@
             block_state.negative_add_time_ids = block_state.add_time_ids
 
         block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
-        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
+        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(
+            device=device
+        )
 
         self.set_block_state(state, block_state)
         return components, state
@@ -1229,7 +1248,9 @@ class StableDiffusionXLLCMStep(PipelineBlock):
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [ComponentSpec("unet", UNet2DConditionModel),]
+        return [
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
 
     @property
     def description(self) -> str:
@@ -1290,30 +1311,30 @@
         assert emb.shape == (w.shape[0], embedding_dim)
         return emb
 
-
     def check_input(self, unet, embedded_guidance_scale):
-
         if embedded_guidance_scale is not None and unet.config.time_cond_proj_dim is None:
-            raise ValueError(f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None")
+            raise ValueError(
+                f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None"
+            )
 
         if embedded_guidance_scale is None and unet.config.time_cond_proj_dim is not None:
-            raise ValueError(f"unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")
+            raise ValueError("unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")
 
-
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
+
         device = components._execution_device
         dtype = block_state.dtype if block_state.dtype is not None else components.unet.dtype
 
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
-
         # Optionally get Guidance Scale Embedding for LCM
         block_state.timestep_cond = None
-
-        guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
+
+        guidance_scale_tensor = (
+            torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
+        )
         block_state.timestep_cond = self.get_guidance_scale_embedding(
             guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
         ).to(device=device, dtype=dtype)
@@ -1476,9 +1497,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
         if isinstance(controlnet, MultiControlNetModel) and isinstance(
             block_state.controlnet_conditioning_scale, float
         ):
-            block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(
-                controlnet.nets
-            )
+            block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
         else:
             block_state.conditioning_scale = block_state.controlnet_conditioning_scale
 
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
index 6ed5d72bf6..f68b0be4b6 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
@@ -130,9 +130,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
             latents_std = (
                 torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
             )
-            latents = (
-                latents * latents_std / components.vae.config.scaling_factor + latents_mean
-            )
+            latents = latents * latents_std / components.vae.config.scaling_factor + latents_mean
         else:
             latents = latents / components.vae.config.scaling_factor
 
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
index ea95d850d0..bcfe4e9b1d 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
@@ -15,6 +15,7 @@
 from typing import List, Optional, Tuple
 
 import torch
+from PIL import Image
 from transformers import (
     CLIPImageProcessor,
     CLIPTextModel,
@@ -39,8 +40,6 @@
 from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import StableDiffusionXLModularPipeline
 
-from PIL import Image
-
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -67,7 +66,6 @@ def get_clip_prompt_embeds(
     clip_skip=None,
     max_length=None,
 ):
-
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
@@ -79,9 +77,7 @@
     text_input_ids = text_inputs.input_ids
     untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-        text_input_ids, untruncated_ids
-    ):
+    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
         removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
         logger.warning(
             "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -99,24 +95,20 @@ def get_clip_prompt_embeds(
     else:
         # "2" because SDXL always indexes from the penultimate layer.
         prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
-
+
     return prompt_embeds, pooled_prompt_embeds
 
 
 # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
 def encode_vae_image(
-    image: torch.Tensor,
-    vae: AutoencoderKL,
-    generator: torch.Generator,
-    dtype: torch.dtype,
-    device: torch.device
+    image: torch.Tensor, vae: AutoencoderKL, generator: torch.Generator, dtype: torch.dtype, device: torch.device
 ):
     latents_mean = latents_std = None
     if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
         latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
     if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
         latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
-
+
     image = image.to(device=device, dtype=dtype)
 
     if vae.config.force_upcast:
@@ -131,8 +123,7 @@ def encode_vae_image(
 
     if isinstance(generator, list):
         image_latents = [
-            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i])
-            for i in range(image.shape[0])
+            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
         ]
         image_latents = torch.cat(image_latents, dim=0)
     else:
@@ -200,7 +191,11 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="IP adapter image embeddings"),
-            OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative IP adapter image embeddings"),
+            OutputParam(
+                "negative_ip_adapter_embeds",
+                type_hint=List[torch.Tensor],
+                description="Negative IP adapter image embeddings",
+            ),
         ]
 
     @staticmethod
@@ -229,7 +224,6 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
 
         return image_embeds, uncond_image_embeds
 
-
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
@@ -245,7 +239,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
 
         if len(block_state.ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
             raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(block_state.ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
             )
 
         for single_ip_adapter_image, image_proj_layer in zip(
@@ -333,20 +327,17 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
     @staticmethod
     def check_inputs(prompt, prompt_2, negative_prompt, negative_prompt_2):
-
         if not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if prompt_2 is not None and (
-            not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
-        ):
+
+        if prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
-
+
         if negative_prompt is not None and (
             not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
-
+
         if negative_prompt_2 is not None and (
             not isinstance(negative_prompt_2, str) and not isinstance(negative_prompt_2, list)
         ):
             raise ValueError(
@@ -394,7 +385,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
         """
         dtype = components.text_encoder_2.dtype
 
-
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
@@ -421,7 +411,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
                     adjust_lora_scale_text_encoder(text_encoder, lora_scale)
                 else:
                     scale_lora_layers(text_encoder, lora_scale)
-
+
         # Define prompts
         prompt_2 = prompt_2 or prompt
         prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -436,12 +426,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
                 prompt = components.maybe_convert_prompt(prompt, tokenizer)
 
             prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds(
-                prompt=prompt, 
-                text_encoder=text_encoder, 
-                tokenizer=tokenizer, 
-                device=device, 
+                prompt=prompt,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                device=device,
                 clip_skip=clip_skip,
-                max_length=tokenizer.model_max_length
+                max_length=tokenizer.model_max_length,
             )
 
             prompt_embeds_list.append(prompt_embeds)
@@ -492,12 +482,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
                 max_length = prompt_embeds.shape[1]
 
                 negative_prompt_embeds, negative_pooled_prompt_embeds = get_clip_prompt_embeds(
-                    prompt=negative_prompt, 
-                    text_encoder=text_encoder, 
-                    tokenizer=tokenizer, 
-                    device=device, 
+                    prompt=negative_prompt,
+                    text_encoder=text_encoder,
+                    tokenizer=tokenizer,
+                    device=device,
                     clip_skip=None,
-                    max_length=max_length
+                    max_length=max_length,
                 )
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
                 if negative_pooled_prompt_embeds.ndim == 2:
@@ -523,8 +513,10 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         # Get inputs and intermediates
         block_state = self.get_block_state(state)
-
-        self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2)
+
+        self.check_inputs(
+            block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2
+        )
 
         device = components._execution_device
 
@@ -608,11 +600,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
 
         # Encode image into latents
         block_state.image_latents = encode_vae_image(
-            image=image,
-            vae=components.vae,
-            generator=block_state.generator,
-            dtype=dtype,
-            device=device
+            image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
         )
 
         self.set_block_state(state, block_state)
@@ -681,14 +669,13 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
                 description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
             ),
             OutputParam(
-                "mask", 
-                type_hint=torch.Tensor, 
+                "mask",
+                type_hint=torch.Tensor,
                 description="The mask to apply on the latents for the inpainting generation.",
             ),
         ]
-
+
     def check_inputs(self, image, mask_image, padding_mask_crop):
-
         if padding_mask_crop is not None and not isinstance(image, Image.Image):
             raise ValueError(
                 f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
@@ -696,10 +683,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
 
         if padding_mask_crop is not None and not isinstance(mask_image, Image.Image):
             raise ValueError(
-                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
-                f" {type(mask_image)}."
+                f"The mask image should be a PIL image when inpainting mask crop, but is of type {type(mask_image)}."
             )
-
+
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
@@ -738,32 +724,24 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             resize_mode=resize_mode,
             crops_coords=block_state.crops_coords,
         )
-
+
         masked_image = image * (mask_image < 0.5)
 
         # Prepare image latent variables
        block_state.image_latents = encode_vae_image(
-            image=image,
-            vae=components.vae,
-            generator=block_state.generator,
-            dtype=dtype,
-            device=device
+            image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
        )
 
         # Prepare masked image latent variables
         block_state.masked_image_latents = encode_vae_image(
-            image=masked_image,
-            vae=components.vae,
-            generator=block_state.generator,
-            dtype=dtype,
-            device=device
+            image=masked_image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
         )
-
+
         # resize mask to match the image latents
         _, _, height_latents, width_latents = block_state.image_latents.shape
         block_state.mask = torch.nn.functional.interpolate(
-            mask_image, 
-            size=(height_latents, width_latents), 
+            mask_image,
+            size=(height_latents, width_latents),
         )
         block_state.mask = block_state.mask.to(dtype=dtype, device=device)
 
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
index c38eb8c632..93998ab6cd 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
@@ -23,10 +23,10 @@ from .before_denoise import (
     StableDiffusionXLImg2ImgSetTimestepsStep,
     StableDiffusionXLInpaintPrepareLatentsStep,
     StableDiffusionXLInputStep,
+    StableDiffusionXLLCMStep,
     StableDiffusionXLPrepareAdditionalConditioningStep,
     StableDiffusionXLPrepareLatentsStep,
     StableDiffusionXLSetTimestepsStep,
-    StableDiffusionXLLCMStep,
 )
 from .decoders import (
     StableDiffusionXLDecodeStep,
@@ -372,7 +372,6 @@ IP_ADAPTER_BLOCKS = InsertableDict(
 )
 
 LCM_BLOCKS = InsertableDict(
-
     [
         ("lcm", StableDiffusionXLAutoLCMStep),
     ]
 )
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
index f337a89a8b..c169786f16 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
@@ -89,7 +89,7 @@ class StableDiffusionXLModularPipeline(
         if hasattr(self, "vae") and self.vae is not None:
             num_channels_latents = self.vae.config.latent_channels
         return num_channels_latents
-
+
     @property
     def requires_unconditional_embeds(self):
         # by default, always prepare unconditional embeddings
@@ -101,7 +101,7 @@ class StableDiffusionXLModularPipeline(
 
         elif hasattr(self, "guider") and self.guider is not None:
             requires_unconditional_embeds = self.guider.num_conditions > 1
-
+
         elif not hasattr(self, "guider") or self.guider is None:
             requires_unconditional_embeds = False