From dc6a4d4cb4e7d46e3ac36e1c18be14ede174e104 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Wed, 6 Aug 2025 04:57:31 +0200
Subject: [PATCH] more

---
 .../modular_pipeline_utils.py                 |   5 +-
 .../stable_diffusion_xl/before_denoise.py     |  34 +-
 .../stable_diffusion_xl/encoders.py           | 703 ++++++++----------
 .../stable_diffusion_xl/modular_pipeline.py   |  10 +
 4 files changed, 329 insertions(+), 423 deletions(-)

diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
index f2fc015e94..2547360aa2 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
@@ -91,7 +91,10 @@ class ComponentSpec:
     type_hint: Optional[Type] = None
     description: Optional[str] = None
     config: Optional[FrozenDict] = None
-    # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
+    # YiYi TODO: currently `required` is only used to mark optional components that the block can run without. In the future:
+    # 1. the spec for an optional component should have lower priority when combined in sequential/auto blocks
+    # 2. should not need to define default_creation_method for optional components
+    required: bool = True
     repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
     subfolder: Optional[str] = field(default="", metadata={"loading": True})
     variant: Optional[str] = field(default=None, metadata={"loading": True})
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index b367fc7c62..e7b7d2550e 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -418,21 +418,21 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
         device = components._execution_device
 
         block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
-            components.scheduler,
-            block_state.num_inference_steps,
-            block_state.device,
-            block_state.timesteps,
-            block_state.sigmas,
+            scheduler=components.scheduler,
+            num_inference_steps=block_state.num_inference_steps,
+            device=device,
+            timesteps=block_state.timesteps,
+            sigmas=block_state.sigmas,
         )
 
         def denoising_value_valid(dnv):
             return isinstance(dnv, float) and 0 < dnv < 1
 
         block_state.timesteps, block_state.num_inference_steps = self.get_timesteps(
-            components,
-            block_state.num_inference_steps,
-            block_state.strength,
-            device,
+            components=components,
+            num_inference_steps=block_state.num_inference_steps,
+            strength=block_state.strength,
+            device=device,
             denoising_start=block_state.denoising_start
             if denoising_value_valid(block_state.denoising_start)
             else None,
@@ -498,14 +498,14 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
-        block_state.device = components._execution_device
+        device = components._execution_device
 
         block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
-            components.scheduler,
-            block_state.num_inference_steps,
-            block_state.device,
-            block_state.timesteps,
-            block_state.sigmas,
+            scheduler=components.scheduler,
+            num_inference_steps=block_state.num_inference_steps,
+            device=device,
+            timesteps=block_state.timesteps,
+            sigmas=block_state.sigmas,
         )
 
         if (
@@ -581,7 
+581,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", ), InputParam( - "mask", + "processed_mask_image", required=True, type_hint=torch.Tensor, description="The mask for the inpainting generation. Can be generated in vae_encode step.", @@ -591,7 +591,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): type_hint=torch.Tensor, description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.", ), - InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."), ] @property diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index 99a677dfe6..ac6ebe78fa 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -57,6 +57,99 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +def get_clip_prompt_embeds( + prompt, + text_encoder, + tokenizer, + device, + clip_skip=None, + max_length=None, +): + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length if max_length is not None else tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only using the pooled output of the text_encoder_2, which has 2 dimensions + # (pooled output for text_encoder has 3 dimensions) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + return prompt_embeds, pooled_prompt_embeds + + +# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components +def encode_vae_image( + image: torch.Tensor, + vae: AutoencoderKL, + generator: torch.Generator, + dtype: torch.dtype, + device: torch.device +): + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + + image = image.to(device=device, dtype=dtype) + + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != image.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {image.shape[0]}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + image_latents = (image_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + image_latents = vae.config.scaling_factor * image_latents + + return image_latents + + class StableDiffusionXLIPAdapterStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -86,6 +179,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config", + required=False, ), ] @@ -103,12 +197,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam( - "negative_ip_adapter_embeds", - type_hint=torch.Tensor, - description="Negative IP adapter image embeddings", - ), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative IP adapter image embeddings"), ] @staticmethod @@ -137,79 +227,36 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): return image_embeds, uncond_image_embeds - # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, - components, - ip_adapter_image, - ip_adapter_image_embeds, - device, - num_images_per_prompt, - prepare_unconditional_embeds, - ): - image_embeds = [] - if prepare_unconditional_embeds: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != 
len(components.unet.encoder_hid_proj.image_projection_layers):
-                raise ValueError(
-                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-                )
-
-            for single_ip_adapter_image, image_proj_layer in zip(
-                ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
-            ):
-                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-                single_image_embeds, single_negative_image_embeds = self.encode_image(
-                    components, single_ip_adapter_image, device, 1, output_hidden_state
-                )
-
-                image_embeds.append(single_image_embeds[None, :])
-                if prepare_unconditional_embeds:
-                    negative_image_embeds.append(single_negative_image_embeds[None, :])
-        else:
-            for single_image_embeds in ip_adapter_image_embeds:
-                if prepare_unconditional_embeds:
-                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    negative_image_embeds.append(single_negative_image_embeds)
-                image_embeds.append(single_image_embeds)
-
-        ip_adapter_image_embeds = []
-        for i, single_image_embeds in enumerate(image_embeds):
-            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-            if prepare_unconditional_embeds:
-                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
-
-            single_image_embeds = single_image_embeds.to(device=device)
-            ip_adapter_image_embeds.append(single_image_embeds)
-
-        return ip_adapter_image_embeds
 
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
-        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
-        block_state.device = components._execution_device
+        device = components._execution_device
 
-        block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
-            components,
-            ip_adapter_image=block_state.ip_adapter_image,
-            ip_adapter_image_embeds=None,
-            device=block_state.device,
-            num_images_per_prompt=1,
-            prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
-        )
-        if block_state.prepare_unconditional_embeds:
+        block_state.ip_adapter_embeds = []
+        if components.requires_unconditional_embeds:
             block_state.negative_ip_adapter_embeds = []
-            for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
-                negative_image_embeds, image_embeds = image_embeds.chunk(2)
-                block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
-                block_state.ip_adapter_embeds[i] = image_embeds
+
+        if not isinstance(block_state.ip_adapter_image, list):
+            block_state.ip_adapter_image = [block_state.ip_adapter_image]
+
+        if len(block_state.ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
+            raise ValueError(
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(block_state.ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ ) + + for single_ip_adapter_image, image_proj_layer in zip( + block_state.ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + block_state.ip_adapter_embeds.append(single_image_embeds[None, :]) + if components.requires_unconditional_embeds: + block_state.negative_ip_adapter_embeds.append(single_negative_image_embeds[None, :]) self.set_block_state(state, block_state) return components, state @@ -225,15 +272,16 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder", CLIPTextModel, required=False), ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer", CLIPTokenizer, required=False), ComponentSpec("tokenizer_2", CLIPTokenizer), ComponentSpec( "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config", + required=False, ), ] @@ -244,7 +292,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): @property def inputs(self) -> List[InputParam]: return [ - InputParam("prompt"), + InputParam("prompt", required=True), InputParam("prompt_2"), InputParam("negative_prompt"), InputParam("negative_prompt_2"), @@ -282,15 +330,25 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): ] @staticmethod - def check_inputs(block_state): - if block_state.prompt is not None and ( - not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + def check_inputs(prompt, prompt_2, negative_prompt, negative_prompt_2): + + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_2 is not None and ( + not isinstance(prompt_2, str) and not isinstance(prompt_2, list) ): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") - elif block_state.prompt_2 is not None and ( - not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list) + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if negative_prompt_2 is not None and ( + not isinstance(negative_prompt_2, str) and not isinstance(negative_prompt_2, list) + ): + raise ValueError(f"`negative_prompt_2` has to be of type `str` or `list` but is {type(negative_prompt_2)}") @staticmethod def encode_prompt( @@ -298,14 +356,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): prompt: str, prompt_2: Optional[str] = None, device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prepare_unconditional_embeds: bool = True, + requires_unconditional_embeds: bool = True, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = 
None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): @@ -331,20 +384,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): @@ -352,31 +391,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): the output of the pre-final layer will be used for computing the prompt embeddings. 
""" device = device or components._execution_device + dtype = components.text_encoder_2.dtype - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): - components._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if components.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) - else: - scale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) - else: - scale_lora_layers(components.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] + batch_size = len(prompt) # Define tokenizers and text encoders tokenizers = ( @@ -389,58 +409,56 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): if components.text_encoder is not None else [components.text_encoder_2] ) + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - prompt = components.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] + # dynamically adjust the LoRA scale + for text_encoder in text_encoders: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(text_encoder, lora_scale) else: - # "2" because SDXL always indexes from the penultimate layer. 
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + scale_lora_layers(text_encoder, lora_scale) + + # Define prompts + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + prompts = [prompt, prompt_2] - prompt_embeds_list.append(prompt_embeds) + # generate prompt_embeds & pooled_prompt_embeds + prompt_embeds_list = [] + pooled_prompt_embeds_list = [] - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds( + prompt=prompt, + text_encoder=text_encoder, + tokenizer=tokenizer, + device=device, + clip_skip=clip_skip, + max_length=tokenizer.model_max_length + ) + + prompt_embeds_list.append(prompt_embeds) + if pooled_prompt_embeds.ndim == 2: + pooled_prompt_embeds_list.append(pooled_prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = torch.concat(pooled_prompt_embeds_list, dim=0) + + negative_prompt_embeds = None + negative_pooled_prompt_embeds = None - # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt - if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + # generate negative_prompt_embeds & negative_pooled_prompt_embeds + if requires_unconditional_embeds and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif prepare_unconditional_embeds and negative_prompt_embeds is None: + elif requires_unconditional_embeds: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt @@ -451,87 +469,52 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): ) uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): + if batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] + if batch_size != len(negative_prompt_2): + raise ValueError( + f"`negative_prompt_2`: {negative_prompt_2} has batch size {len(negative_prompt_2)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt_2` matches" + " the batch size of `prompt`." 
+            )
+            uncond_tokens = [negative_prompt, negative_prompt_2]
 
             negative_prompt_embeds_list = []
+            negative_pooled_prompt_embeds_list = []
             for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                 if isinstance(components, TextualInversionLoaderMixin):
                     negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
 
                 max_length = prompt_embeds.shape[1]
-                uncond_input = tokenizer(
-                    negative_prompt,
-                    padding="max_length",
-                    max_length=max_length,
-                    truncation=True,
-                    return_tensors="pt",
+                negative_prompt_embeds, negative_pooled_prompt_embeds = get_clip_prompt_embeds(
+                    prompt=negative_prompt,
+                    text_encoder=text_encoder,
+                    tokenizer=tokenizer,
+                    device=device,
+                    clip_skip=None,
+                    max_length=max_length
                 )
-
-                negative_prompt_embeds = text_encoder(
-                    uncond_input.input_ids.to(device),
-                    output_hidden_states=True,
-                )
-                # We are only ALWAYS interested in the pooled output of the final text encoder
-                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
-                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
+                if negative_pooled_prompt_embeds.ndim == 2:
+                    negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+            negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0)
 
-        if components.text_encoder_2 is not None:
-            prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
-        else:
-            prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
+        if requires_unconditional_embeds:
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype=dtype, device=device)
 
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        if prepare_unconditional_embeds:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            if components.text_encoder_2 is not None:
-                negative_prompt_embeds = negative_prompt_embeds.to(
-                    dtype=components.text_encoder_2.dtype, device=device
-                )
-            else:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
-        if prepare_unconditional_embeds:
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-                bs_embed * num_images_per_prompt, -1
-            )
-
-        if components.text_encoder is not None:
+        for text_encoder in text_encoders:
             if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(components.text_encoder, lora_scale)
-
-        if components.text_encoder_2 is not None:
-            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(components.text_encoder_2, lora_scale)
+                unscale_lora_layers(text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
@@ -539,13 +522,14 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         # Get inputs and intermediates
         block_state = self.get_block_state(state)
-        self.check_inputs(block_state)
+
+        self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2)
 
-        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
-        block_state.device = components._execution_device
+        device = components._execution_device
+        dtype = components.text_encoder_2.dtype
 
         # Encode input prompt
-        block_state.text_encoder_lora_scale = (
+        lora_scale = (
             block_state.cross_attention_kwargs.get("scale", None)
             if block_state.cross_attention_kwargs is not None
             else None
@@ -557,18 +541,13 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             block_state.negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             components,
-            block_state.prompt,
-            block_state.prompt_2,
-            block_state.device,
-            1,
-            block_state.prepare_unconditional_embeds,
-            block_state.negative_prompt,
-            block_state.negative_prompt_2,
-            prompt_embeds=None,
-            negative_prompt_embeds=None,
-            pooled_prompt_embeds=None,
-            negative_pooled_prompt_embeds=None,
-            lora_scale=block_state.text_encoder_lora_scale,
+            prompt=block_state.prompt,
+            prompt_2=block_state.prompt_2,
+            device=device,
+            requires_unconditional_embeds=components.requires_unconditional_embeds,
+            negative_prompt=block_state.negative_prompt,
+            negative_prompt_2=block_state.negative_prompt_2,
+            lora_scale=lora_scale,
             clip_skip=block_state.clip_skip,
         )
         # Add outputs
@@ -599,8 +578,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam("image", required=True),
-            InputParam("height"),
-            InputParam("width"),
         ]
 
     @property
@@ -608,11 +585,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
         return [
             InputParam("generator"),
             InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
-            InputParam(
-                "preprocess_kwargs",
-                type_hint=Optional[dict],
-                description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
-            ),
         ]
 
     @property
@@ -622,68 +594,30 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
                 "image_latents",
                 type_hint=torch.Tensor,
                 description="The latents representing the reference image for image-to-image/inpainting generation",
+            ),
+            OutputParam(
+                "processed_image",
+                type_hint=PIL.Image.Image,
+                description="The preprocessed image",
             )
         ]
 
-    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
-    # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
-    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
-        latents_mean = latents_std = None
-        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
-            latents_mean = 
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} - block_state.device = components._execution_device - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.image = components.image_processor.preprocess( - block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs - ) - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + device = components._execution_device + dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.batch_size = block_state.image.shape[0] + block_state.processed_image = components.image_processor.preprocess(block_state.image) - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" - f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - block_state.image_latents = self._encode_vae_image( - components, image=block_state.image, generator=block_state.generator + # Encode image into latents + block_state.image_latents = encode_vae_image( + image=block_state.processed_image, + vae=components.vae, + generator=block_state.generator, + dtype=dtype, + device=device ) self.set_block_state(state, block_state) @@ -741,7 +675,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): OutputParam( "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image" ), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), OutputParam( "masked_image_latents", type_hint=torch.Tensor, @@ -752,129 +685,89 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask", ), + OutputParam( + "processed_image", + type_hint=PIL.Image.Image, + description="The preprocessed image", + ), + OutputParam( + "processed_mask_image", + type_hint=torch.Tensor, + description="The preprocessed mask image", + ), ] + + def check_inputs(self, image, mask_image, padding_mask_crop): + + if padding_mask_crop is not None and not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the 
mask to the latents
-        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
-        # and half precision
-        mask = torch.nn.functional.interpolate(
-            mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
-        )
-        mask = mask.to(device=device, dtype=dtype)
-
-        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        if mask.shape[0] < batch_size:
-            if not batch_size % mask.shape[0] == 0:
-                raise ValueError(
-                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
-                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
-                    " of masks that you pass is divisible by the total requested batch size."
-                )
-            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
-
-
-        masked_image = masked_image.to(device=device, dtype=dtype)
-        masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
-
-        return mask, masked_image_latents
-
+        if padding_mask_crop is not None and not isinstance(mask_image, PIL.Image.Image):
+            raise ValueError(
+                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                f" {type(mask_image)}."
+            )
+
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
-        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
-        block_state.device = components._execution_device
+        self.check_inputs(block_state.image, block_state.mask_image, block_state.padding_mask_crop)
+
+        dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+        device = components._execution_device
 
-        if block_state.height is None:
-            block_state.height = components.default_height
-        if block_state.width is None:
-            block_state.width = components.default_width
+        height = block_state.height if block_state.height is not None else components.default_height
+        width = block_state.width if block_state.width is not None else components.default_width
 
         if block_state.padding_mask_crop is not None:
-            block_state.crops_coords = components.mask_processor.get_crop_region(
-                block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
+            crops_coords = components.mask_processor.get_crop_region(
+                mask_image=block_state.mask_image, width=width, height=height, pad=block_state.padding_mask_crop
             )
-            block_state.resize_mode = "fill"
+            resize_mode = "fill"
         else:
-            block_state.crops_coords = None
-            block_state.resize_mode = "default"
+            crops_coords = None
+            resize_mode = "default"
 
-        block_state.image = components.image_processor.preprocess(
+        block_state.processed_image = components.image_processor.preprocess(
             block_state.image,
-            height=block_state.height,
-            width=block_state.width,
-            crops_coords=block_state.crops_coords,
-            resize_mode=block_state.resize_mode,
+            height=height,
+            width=width,
+            crops_coords=crops_coords,
+            resize_mode=resize_mode,
         )
-        block_state.image = block_state.image.to(dtype=torch.float32)
-        block_state.mask = components.mask_processor.preprocess(
+        block_state.processed_image = block_state.processed_image.to(dtype=torch.float32)
+
+        block_state.processed_mask_image = components.mask_processor.preprocess(
             block_state.mask_image,
-            height=block_state.height,
-            width=block_state.width,
-            resize_mode=block_state.resize_mode,
-            crops_coords=block_state.crops_coords,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
         )
-        block_state.masked_image = 
block_state.image * (block_state.mask < 0.5) + + masked_image = block_state.processed_image * (block_state.processed_mask_image < 0.5) - block_state.batch_size = block_state.image.shape[0] - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - block_state.image_latents = self._encode_vae_image( - components, image=block_state.image, generator=block_state.generator + block_state.image_latents = encode_vae_image( + image=block_state.processed_image, + vae=components.vae, + generator=block_state.generator, + dtype=dtype, + device=device ) # 7. Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image, - block_state.batch_size, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, + block_state.masked_image_latents = encode_vae_image( + image=masked_image, + vae=components.vae, + generator=block_state.generator, + dtype=dtype, + device=device ) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py index fc030fae56..84dd0c0ee3 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -89,6 +89,16 @@ class StableDiffusionXLModularPipeline( if hasattr(self, "vae") and self.vae is not None: num_channels_latents = self.vae.config.latent_channels return num_channels_latents + + @property + def requires_unconditional_embeds(self): + # by default, always prepare unconditional embeddings + requires_unconditional_embeds = True + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider.num_conditions > 1 + + return requires_unconditional_embeds # YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks
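Usage sketch for the module-level `encode_vae_image` helper that the VAE-encode blocks above now delegate to: a minimal, illustrative driver, assuming this patch is applied and an SDXL-compatible VAE is available. The checkpoint id, tensor shape, and seed below are placeholder assumptions for illustration, not values taken from the diff.

    import torch
    from diffusers import AutoencoderKL
    from diffusers.modular_pipelines.stable_diffusion_xl.encoders import encode_vae_image

    # Illustrative checkpoint; any SDXL-compatible AutoencoderKL should behave the same way.
    vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vae.to(device)

    # Stand-in for a preprocessed image in [-1, 1], e.g. the processed_image produced by VaeImageProcessor.preprocess.
    image = torch.randn(1, 3, 1024, 1024)
    generator = torch.Generator(device=device).manual_seed(0)

    latents = encode_vae_image(
        image=image,
        vae=vae,
        generator=generator,
        dtype=vae.dtype,
        device=device,
    )
    print(latents.shape)  # torch.Size([1, 4, 128, 128]) given SDXL's 8x-downsampling VAE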