mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
attention_kwargs -> cross_attention_kwargs.
@@ -364,12 +364,12 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         timestep: torch.LongTensor,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
+        if cross_attention_kwargs is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
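
Note: the copy-before-pop in this hunk keeps the caller's kwargs dict intact across denoising steps; only the local copy loses the "scale" key. A minimal sketch of the same pattern in isolation (the helper name extract_lora_scale is illustrative, not part of the diff):

from typing import Any, Dict, Optional

def extract_lora_scale(cross_attention_kwargs: Optional[Dict[str, Any]]) -> float:
    # Copy first so the caller's dict still holds "scale" on the next call.
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        return cross_attention_kwargs.pop("scale", 1.0)
    return 1.0
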
@@ -377,9 +377,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+            if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
                 logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                    "Passing `scale` via `cross_attention_kwargs` when not using the PEFT backend is ineffective."
                 )

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
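
For context, the scale only takes effect on the PEFT code path: diffusers weights each LoRA layer by lora_scale around the forward pass and restores it afterwards. A simplified sketch of how these helpers are typically paired (transformer and transformer_forward are hypothetical stand-ins, not code from this commit):

from diffusers.utils import scale_lora_layers, unscale_lora_layers

scale_lora_layers(transformer, lora_scale)    # weight LoRA deltas by lora_scale
out = transformer_forward()                   # stand-in for the real forward body
unscale_lora_layers(transformer, lora_scale)  # restore the original weighting
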
@@ -574,8 +574,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         return self._guidance_scale

     @property
-    def attention_kwargs(self):
-        return self._attention_kwargs
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs

     @property
     def do_classifier_free_guidance(self):
@@ -613,7 +613,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         return_dict: bool = True,
         clean_caption: bool = True,
         use_resolution_binning: bool = True,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 300,
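
Callers that passed attention_kwargs to SanaPipeline.__call__ need the new keyword after this change. A minimal before/after sketch (pipe is a hypothetical, already-loaded pipeline instance):

# before: pipe(prompt, attention_kwargs={"scale": 0.5})
# after:
image = pipe(prompt, cross_attention_kwargs={"scale": 0.5}).images[0]
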
@@ -686,7 +686,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
-            attention_kwargs: TODO
+            cross_attention_kwargs: TODO
             clean_caption (`bool`, *optional*, defaults to `True`):
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                 be installed. If the dependencies are not installed, the embeddings will be created from the raw
@@ -747,7 +747,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         )

         self._guidance_scale = guidance_scale
-        self._attention_kwargs = attention_kwargs
+        self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False

         # 2. Default height and width to transformer
@@ -759,7 +759,9 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+        lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )

         # 3. Encode input prompt
         (
@@ -829,7 +831,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
                     return_dict=False,
-                    attention_kwargs=self.attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                 )[0]
                 noise_pred = noise_pred.float()
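
Taken together, the rename gives SanaPipeline the same LoRA-scaling entry point other diffusers pipelines expose through cross_attention_kwargs. An end-to-end sketch, assuming a SANA checkpoint and a LoRA adapter are available at the placeholder IDs below (both IDs are illustrative):

import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained("<sana-checkpoint-id>", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("<lora-repo-id>")  # PEFT backend required for `scale` to take effect

image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    cross_attention_kwargs={"scale": 0.7},  # LoRA strength, popped as `scale` in the transformer
).images[0]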