1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

attention_kwargs -> cross_attention_kwargs.

This commit is contained in:
sayakpaul
2024-12-16 11:40:12 +05:30
parent f219198ce9
commit 23433bf9bc
2 changed files with 15 additions and 13 deletions

View File

@@ -364,12 +364,12 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
timestep: torch.LongTensor,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
if cross_attention_kwargs is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
@@ -377,9 +377,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
"Passing `scale` via `cross_attention_kwargs` when not using the PEFT backend is ineffective."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.

View File

@@ -574,8 +574,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
return self._guidance_scale
@property
def attention_kwargs(self):
return self._attention_kwargs
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def do_classifier_free_guidance(self):
@@ -613,7 +613,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
return_dict: bool = True,
clean_caption: bool = True,
use_resolution_binning: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 300,
@@ -686,7 +686,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
attention_kwargs: TODO
cross_attention_kwargs: TODO
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
@@ -747,7 +747,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Default height and width to transformer
@@ -759,7 +759,9 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
# 3. Encode input prompt
(
@@ -829,7 +831,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
encoder_attention_mask=prompt_attention_mask,
timestep=timestep,
return_dict=False,
attention_kwargs=self.attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
)[0]
noise_pred = noise_pred.float()