diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py
index 41224e42d2..14ec557b8e 100644
--- a/src/diffusers/models/transformers/sana_transformer.py
+++ b/src/diffusers/models/transformers/sana_transformer.py
@@ -364,12 +364,12 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         timestep: torch.LongTensor,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
+        if cross_attention_kwargs is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0

@@ -377,9 +377,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+            if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
                 logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                    "Passing `scale` via `cross_attention_kwargs` when not using the PEFT backend is ineffective."
                 )

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index b9d747fe5e..ea6480d35a 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -574,8 +574,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         return self._guidance_scale

     @property
-    def attention_kwargs(self):
-        return self._attention_kwargs
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs

     @property
     def do_classifier_free_guidance(self):
@@ -613,7 +613,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         return_dict: bool = True,
         clean_caption: bool = True,
         use_resolution_binning: bool = True,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 300,
@@ -686,7 +686,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
-            attention_kwargs: TODO
+            cross_attention_kwargs: TODO
             clean_caption (`bool`, *optional*, defaults to `True`):
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed.
                 If the dependencies are not installed, the embeddings will be created from the raw
@@ -747,7 +747,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         )

         self._guidance_scale = guidance_scale
-        self._attention_kwargs = attention_kwargs
+        self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False

         # 2. Default height and width to transformer
@@ -759,7 +759,9 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
            batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+        lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )

         # 3. Encode input prompt
         (
@@ -829,7 +831,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
                     return_dict=False,
-                    attention_kwargs=self.attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                 )[0]
                 noise_pred = noise_pred.float()

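
For context, a minimal usage sketch of the renamed argument after this change. The checkpoint id, LoRA path, and scale value below are illustrative placeholders, not part of the diff; this only assumes the renamed `cross_attention_kwargs` parameter introduced above.

```python
import torch
from diffusers import SanaPipeline

# Illustrative checkpoint id; substitute the Sana checkpoint you actually use.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Hypothetical LoRA path, loaded via SanaLoraLoaderMixin.
pipe.load_lora_weights("path/to/sana_lora")

# With this change, the LoRA scale is passed through `cross_attention_kwargs`
# (previously `attention_kwargs`) and popped as `scale` inside the transformer.
image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    cross_attention_kwargs={"scale": 0.8},
).images[0]
```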