mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
handle lora scale and clip skip in lpw sd and sdxl community pipelines (#8988)
* handle lora scale and clip skip in lpw sd and sdxl * use StableDiffusionLoraLoaderMixin * use StableDiffusionXLLoraLoaderMixin * style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -13,13 +13,17 @@ from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
|
||||
text_input: torch.Tensor,
|
||||
chunk_length: int,
|
||||
no_boseos_middle: Optional[bool] = True,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||
@@ -214,7 +219,20 @@ def get_unweighted_text_embeddings(
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
||||
if clip_skip is None:
|
||||
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
|
||||
text_embedding = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
@@ -230,7 +248,10 @@ def get_unweighted_text_embeddings(
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||
if clip_skip is None:
|
||||
clip_skip = 0
|
||||
prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
|
||||
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
return text_embeddings
|
||||
|
||||
|
||||
@@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
|
||||
no_boseos_middle: Optional[bool] = False,
|
||||
skip_parsing: Optional[bool] = False,
|
||||
skip_weighting: Optional[bool] = False,
|
||||
clip_skip=None,
|
||||
lora_scale=None,
|
||||
):
|
||||
r"""
|
||||
Prompts can be assigned with local weights using brackets. For example,
|
||||
@@ -268,6 +291,16 @@ def get_weighted_text_embeddings(
|
||||
skip_weighting (`bool`, *optional*, defaults to `False`):
|
||||
Skip the weighting. When the parsing is skipped, it is forced True.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
|
||||
pipe._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(pipe.text_encoder, lora_scale)
|
||||
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
@@ -334,10 +367,7 @@ def get_weighted_text_embeddings(
|
||||
|
||||
# get the embeddings
|
||||
text_embeddings = get_unweighted_text_embeddings(
|
||||
pipe,
|
||||
prompt_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
|
||||
)
|
||||
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
|
||||
if uncond_prompt is not None:
|
||||
@@ -346,6 +376,7 @@ def get_weighted_text_embeddings(
|
||||
uncond_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
|
||||
|
||||
@@ -362,6 +393,11 @@ def get_weighted_text_embeddings(
|
||||
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if pipe.text_encoder is not None:
|
||||
if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(pipe.text_encoder, lora_scale)
|
||||
|
||||
if uncond_prompt is not None:
|
||||
return text_embeddings, uncond_embeddings
|
||||
return text_embeddings, None
|
||||
@@ -549,6 +585,8 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
max_embeddings_multiples=3,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
@@ -597,6 +635,8 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
prompt=prompt,
|
||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
clip_skip=clip_skip,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = prompt_embeds1
|
||||
@@ -790,6 +830,7 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
@@ -865,6 +906,9 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
is_cancelled_callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. If the function returns
|
||||
`True`, the inference will be cancelled.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
@@ -903,6 +947,7 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
@@ -914,6 +959,8 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
max_embeddings_multiples,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
clip_skip=clip_skip,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
dtype = prompt_embeds.dtype
|
||||
|
||||
@@ -1044,6 +1091,7 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
clip_skip=None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
@@ -1101,6 +1149,9 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
is_cancelled_callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. If the function returns
|
||||
`True`, the inference will be cancelled.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
@@ -1135,6 +1186,7 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
is_cancelled_callback=is_cancelled_callback,
|
||||
clip_skip=clip_skip,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
@@ -25,21 +25,25 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import (
|
||||
FromSingleFileMixin,
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
@@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl(
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
lora_scale: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
This function can process long prompt with weights, no length limitation
|
||||
@@ -281,6 +286,24 @@ def get_weighted_text_embeddings_sdxl(
|
||||
"""
|
||||
device = device or pipe._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
|
||||
pipe._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if pipe.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(pipe.text_encoder, lora_scale)
|
||||
|
||||
if pipe.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(pipe.text_encoder_2, lora_scale)
|
||||
|
||||
if prompt_2:
|
||||
prompt = f"{prompt} {prompt_2}"
|
||||
|
||||
@@ -429,6 +452,16 @@ def get_weighted_text_embeddings_sdxl(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if pipe.text_encoder is not None:
|
||||
if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(pipe.text_encoder, lora_scale)
|
||||
|
||||
if pipe.text_encoder_2 is not None:
|
||||
if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(pipe.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
|
||||
@@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
StableDiffusionMixin,
|
||||
FromSingleFileMixin,
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
@@ -561,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
|
||||
Args:
|
||||
@@ -743,7 +776,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
|
||||
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1612,7 +1645,9 @@ class SDXLLongPromptWeightingPipeline(
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 3. Encode input prompt
|
||||
(self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
|
||||
lora_scale = (
|
||||
self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
||||
|
||||
@@ -1627,6 +1662,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
neg_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
dtype = prompt_embeds.dtype
|
||||
|
||||
|
||||
Reference in New Issue
Block a user