diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index d772e89139..21bf42b57d 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -344,6 +344,10 @@ class FluxTransformer2DModel(
         self.gradient_checkpointing = False
 
+    @property
+    def is_chroma(self) -> bool:
+        return isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
+
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
@@ -500,7 +504,7 @@ class FluxTransformer2DModel(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
 
-        is_chroma = isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
+        is_chroma = self.is_chroma
         hidden_states = self.x_embedder(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype) * 1000
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index a7266c3a56..50c0c4cedc 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -191,6 +191,7 @@ class FluxPipeline(
         transformer: FluxTransformer2DModel,
         image_encoder: CLIPVisionModelWithProjection = None,
         feature_extractor: CLIPImageProcessor = None,
+        variant: str = "flux",
     ):
         super().__init__()
 
@@ -213,6 +214,17 @@ class FluxPipeline(
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
         self.default_sample_size = 128
+        if variant not in {"flux", "chroma"}:
+            raise ValueError("`variant` must be `'flux'` or `'chroma'`.")
+
+        self.variant = variant
+
+    def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
+        attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
+        for i, n_tokens in enumerate(length):
+            n_tokens = min(int(n_tokens) + 1, max_sequence_length)  # real tokens plus one padding token, capped
+            attention_mask[i, :n_tokens] = True
+        return attention_mask
 
     def _get_t5_prompt_embeds(
         self,
@@ -236,7 +248,7 @@ class FluxPipeline(
             padding="max_length",
             max_length=max_sequence_length,
             truncation=True,
-            return_length=False,
+            return_length=(self.variant == "chroma"),
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
@@ -250,7 +262,15 @@ class FluxPipeline(
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+        prompt_embeds = self.text_encoder_2(
+            text_input_ids.to(device),
+            output_hidden_states=False,
+            attention_mask=(
+                self._get_chroma_attn_mask(text_inputs.length, max_sequence_length).to(device)
+                if self.variant == "chroma"
+                else None
+            ),
+        )[0]
 
         dtype = self.text_encoder_2.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
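For reference, a minimal, self-contained sketch of what the new `_get_chroma_attn_mask` helper computes. The standalone function and the example lengths below are illustrative stand-ins that mirror the method body: each row unmasks the prompt's real tokens plus one padding token (the `+ 1`), capped at the sequence length.

```python
import torch


def get_chroma_attn_mask(length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
    # One row per prompt; every position starts masked out.
    attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
    for i, n_tokens in enumerate(length):
        # Keep the real tokens visible plus one padding token, capped at the window length.
        n_tokens = min(int(n_tokens) + 1, max_sequence_length)
        attention_mask[i, :n_tokens] = True
    return attention_mask


# Two prompts of 3 and 5 real tokens in a length-8 window:
mask = get_chroma_attn_mask(torch.tensor([3, 5]), max_sequence_length=8)
print(mask.int())
# tensor([[1, 1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 0, 0]], dtype=torch.int32)
```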
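And a hedged usage sketch for the new `variant` argument. The repo id is a stand-in (a Chroma checkpoint exported in the Flux folder layout is assumed here), and since `from_pretrained` already reserves the `variant` kwarg for weight-file selection, the flag is passed at construction time instead.

```python
from diffusers import AutoencoderKL, FluxPipeline, FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

repo = "path/to/chroma-checkpoint"  # stand-in repo id, not a real checkpoint

pipe = FluxPipeline(
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler"),
    vae=AutoencoderKL.from_pretrained(repo, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(repo, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
    text_encoder_2=T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2"),
    tokenizer_2=T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_2"),
    transformer=FluxTransformer2DModel.from_pretrained(repo, subfolder="transformer"),
    variant="chroma",  # any value outside {"flux", "chroma"} now raises ValueError
)
```

With `variant="chroma"`, `_get_t5_prompt_embeds` asks the tokenizer for per-prompt lengths and feeds the boolean mask to the T5 encoder; with the default `"flux"`, behavior is unchanged.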