mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Add attention masking.
This commit is contained in:
@@ -344,6 +344,10 @@ class FluxTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
def is_chroma(self) -> bool:
    """Report whether `self.time_text_embed` is a `CombinedTimestepTextProjChromaEmbeddings` instance."""
    embedder = self.time_text_embed
    return isinstance(embedder, CombinedTimestepTextProjChromaEmbeddings)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
@@ -500,7 +504,7 @@ class FluxTransformer2DModel(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
is_chroma = isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
|
||||
is_chroma = self.is_chroma
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
|
||||
@@ -191,6 +191,7 @@ class FluxPipeline(
|
||||
transformer: FluxTransformer2DModel,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
variant: str = "flux",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -213,6 +214,17 @@ class FluxPipeline(
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 128
|
||||
if variant not in {"flux", "chroma"}:
|
||||
raise ValueError("`variant` must be `'flux' or `'chroma'`.")
|
||||
|
||||
self.variant = variant
|
||||
|
||||
def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
    """Build a boolean attention mask that keeps each prompt's tokens plus one extra position.

    Row ``i`` of the result is ``True`` for positions ``0 .. length[i]`` (inclusive), i.e. the
    first ``length[i] + 1`` positions, clamped to ``max_sequence_length``; all later positions
    are ``False``.

    Args:
        length: 1-D integer tensor with one token count per prompt.
            NOTE(review): assumes these are counts of real (non-padding) tokens -- confirm
            against the tokenizer call, which pads to ``max_length``.
        max_sequence_length: Width of the mask (number of positions per row).

    Returns:
        A ``torch.bool`` tensor of shape ``(length.shape[0], max_sequence_length)`` on
        ``length.device``.
    """
    # The previous per-row `torch.max(n_tokens + 1, max_sequence_length)` both picked the wrong
    # bound (the prefix must be clamped with a minimum) and passed an int where `torch.max`
    # expects a `dim`, which raises for 0-d tensors. Comparing against a position index is
    # equivalent to `position < min(length + 1, max_sequence_length)` and needs no Python loop.
    positions = torch.arange(max_sequence_length, device=length.device)
    return positions.unsqueeze(0) <= length.reshape(-1, 1)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
@@ -236,7 +248,7 @@ class FluxPipeline(
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_length=False,
|
||||
return_length=(self.variant == "chroma"),
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -250,7 +262,15 @@ class FluxPipeline(
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
||||
prompt_embeds = self.text_encoder_2(
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=False,
|
||||
attention_mask=(
|
||||
self._get_chroma_attn_mask(text_inputs.length, max_sequence_length).to(device)
|
||||
if self.variant == "chroma"
|
||||
else None
|
||||
),
|
||||
)[0]
|
||||
|
||||
dtype = self.text_encoder_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
Reference in New Issue
Block a user