From 71574e9dece4d54c559cf825ec5dd890100e1168 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 16 Dec 2024 11:11:11 +0530
Subject: [PATCH] make fix-copies

---
 .../pipelines/pag/pipeline_pag_sana.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
index 081dbef21e..7ee3f0a4db 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -23,15 +23,19 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
+from ...loaders import SanaLoraLoaderMixin
 from ...models import AutoencoderDC, SanaTransformer2DModel
 from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     BACKENDS_MAPPING,
+    USE_PEFT_BACKEND,
     is_bs4_available,
     is_ftfy_available,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -185,6 +189,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
         clean_caption: bool = False,
         max_sequence_length: int = 300,
         complex_human_instruction: Optional[List[str]] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -218,6 +223,15 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
         if device is None:
             device = self._execution_device
 
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -313,6 +327,11 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
 
+        if self.text_encoder is not None:
+            if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
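
Note (not part of the patch): a minimal sketch of how the new `lora_scale` argument to `encode_prompt` could be exercised, assuming `SanaPAGPipeline` also gains `SanaLoraLoaderMixin` so that the `isinstance` guards above take effect and `load_lora_weights` is available. The checkpoint id and LoRA path below are placeholders, not taken from this patch.

import torch
from diffusers import SanaPAGPipeline

# Placeholder checkpoint and LoRA path; substitute real ones.
pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("path/to/sana-lora")  # assumes SanaLoraLoaderMixin support

# With the PEFT backend available, `lora_scale` scales the text encoder's LoRA
# layers only for the duration of prompt encoding; they are unscaled afterwards.
prompt_embeds, prompt_mask, neg_embeds, neg_mask = pipe.encode_prompt(
    "a photo of an astronaut riding a horse", lora_scale=0.8
)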