diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 50c0c4cedc..f6d2e366e4 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -40,7 +40,7 @@ from ...utils import (
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
-from .pipeline_output import FluxPipelineOutput
+from .pipeline_output import ChromaPipelineOutput
 
 
 if is_torch_xla_available():
@@ -57,15 +57,13 @@ EXAMPLE_DOC_STRING = """
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import FluxPipeline
+        >>> from diffusers import ChromaPipeline
 
-        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+        >>> pipe = ChromaPipeline.from_single_file("chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
         >>> prompt = "A cat holding a sign that says hello world"
-        >>> # Depending on the variant being used, the pipeline call will slightly vary.
-        >>> # Refer to the pipeline documentation for more details.
-        >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
-        >>> image.save("flux.png")
+        >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
+        >>> image.save("chroma.png")
         ```
 """
 
@@ -143,7 +141,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class FluxPipeline(
+class ChromaPipeline(
     DiffusionPipeline,
     FluxLoraLoaderMixin,
     FromSingleFileMixin,
@@ -151,27 +149,21 @@ class FluxPipeline(
     FluxIPAdapterMixin,
 ):
     r"""
-    The Flux pipeline for text-to-image generation.
+    The Chroma pipeline for text-to-image generation.
 
-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+    Reference: https://huggingface.co/lodestones/Chroma/
 
     Args:
-        transformer ([`FluxTransformer2DModel`]):
+        transformer ([`ChromaTransformer2DModel`]):
             Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`CLIPTextModel`]):
-            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
-            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
-        text_encoder_2 ([`T5EncoderModel`]):
+        text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
-        tokenizer (`CLIPTokenizer`):
-            Tokenizer of class
-            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
-        tokenizer_2 (`T5TokenizerFast`):
-            Second Tokenizer of class
+        tokenizer (`T5TokenizerFast`):
+            Tokenizer of class
             [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
""" @@ -184,11 +176,9 @@ class FluxPipeline( self, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5TokenizerFast, - transformer: FluxTransformer2DModel, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: ChromaTransformer2DModel, image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, variant: str = "flux", @@ -198,9 +188,7 @@ class FluxPipeline( self.register_modules( vae=vae, text_encoder=text_encoder, - text_encoder_2=text_encoder_2, tokenizer=tokenizer, - tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, image_encoder=image_encoder, @@ -214,10 +202,6 @@ class FluxPipeline( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) self.default_sample_size = 128 - if variant not in {"flux", "chroma"}: - raise ValueError("`variant` must be `'flux' or `'chroma'`.") - - self.variant = variant def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor: attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device) @@ -248,7 +232,7 @@ class FluxPipeline( padding="max_length", max_length=max_sequence_length, truncation=True, - return_length=(self.variant == "chroma"), + return_length=True, return_overflowing_tokens=False, return_tensors="pt", ) @@ -267,8 +251,6 @@ class FluxPipeline( output_hidden_states=False, attention_mask=( self._get_chroma_attn_mask(text_inputs.length, max_sequence_length).to(device) - if self.variant == "chroma" - else None ), )[0] @@ -283,58 +265,12 @@ class FluxPipeline( return prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - def encode_prompt( self, prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], device: 
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
     ):
@@ -343,9 +279,6 @@ class FluxPipeline(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                used in all text-encoders
             device: (`torch.device`):
                 torch device
             num_images_per_prompt (`int`):
@@ -369,21 +302,11 @@ class FluxPipeline(
             # dynamically adjust the LoRA scale
             if self.text_encoder is not None and USE_PEFT_BACKEND:
                 scale_lora_layers(self.text_encoder, lora_scale)
-            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
-                scale_lora_layers(self.text_encoder_2, lora_scale)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
 
         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
-
-            # We only use the pooled prompt output from the CLIPTextModel
-            pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-            )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
+                prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
@@ -396,15 +319,10 @@ class FluxPipeline(
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 
-        if self.text_encoder_2 is not None:
-            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(self.text_encoder_2, lora_scale)
-
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        return prompt_embeds, text_ids
 
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
@@ -456,15 +374,12 @@ class FluxPipeline(
     def check_inputs(
         self,
         prompt,
-        prompt_2,
         height,
         width,
         negative_prompt=None,
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
-        pooled_prompt_embeds=None,
-        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -485,39 +400,18 @@ class FluxPipeline(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                 " only forward one of the two."
             )
-        elif prompt_2 is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                " only forward one of the two."
-            )
         elif prompt is None and prompt_embeds is None:
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
-            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
-        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
-
-        if prompt_embeds is not None and pooled_prompt_embeds is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -649,10 +543,7 @@ class FluxPipeline(
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
         negative_prompt: Union[str, List[str]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        true_cfg_scale: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
@@ -662,13 +553,11 @@ class FluxPipeline(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -683,18 +572,10 @@ class FluxPipeline(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                will be used instead.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
-                not greater than `1`).
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is not greater than `1`).
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
-                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
-            true_cfg_scale (`float`, *optional*, defaults to 1.0):
-                When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -724,9 +605,6 @@ class FluxPipeline(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                 not provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
@@ -742,10 +620,6 @@ class FluxPipeline(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
-                input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -769,7 +643,7 @@ class FluxPipeline(
         Examples:
 
         Returns:
-            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+            [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if `return_dict`
             is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
             images.
         """
@@ -780,15 +654,11 @@ class FluxPipeline(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
-            prompt_2,
             height,
             width,
             negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -811,34 +681,25 @@ class FluxPipeline(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
-        has_neg_prompt = negative_prompt is not None or (
-            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
-        )
-        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+        do_cfg = guidance_scale > 1
         (
             prompt_embeds,
-            pooled_prompt_embeds,
             text_ids,
         ) = self.encode_prompt(
             prompt=prompt,
-            prompt_2=prompt_2,
             prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
-        if do_true_cfg:
+        if do_cfg:
             (
                 negative_prompt_embeds,
-                negative_pooled_prompt_embeds,
                 negative_text_ids,
             ) = self.encode_prompt(
                 prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
                 prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
@@ -933,7 +794,6 @@ class FluxPipeline(
                     hidden_states=latents,
                     timestep=timestep / 1000,
                     guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
@@ -941,7 +801,7 @@ class FluxPipeline(
                     return_dict=False,
                 )[0]
 
-                if do_true_cfg:
+                if do_cfg:
                     if negative_image_embeds is not None:
                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                     neg_noise_pred = self.transformer(
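The last hunk stops right at `neg_noise_pred = self.transformer(`, so the diff does not show how the two predictions are merged once `do_true_cfg` becomes `do_cfg = guidance_scale > 1`. Presumably the merge mirrors the Flux-style true-CFG combination with `guidance_scale` in place of `true_cfg_scale`; a minimal sketch (the helper name is hypothetical, not part of the PR):

```python
import torch


def combine_cfg(noise_pred: torch.Tensor, neg_noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance: start from the negative/unconditional
    # prediction and push toward the conditional one by `guidance_scale`.
    return neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
```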
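Only the first line of `_get_chroma_attn_mask` appears above: it allocates an all-`False` boolean mask of shape `(batch, max_sequence_length)`. A plausible completion is sketched below; the loop body is an assumption rather than the PR's exact code. It also illustrates why the tokenizer is now always called with `return_length=True`: the reported per-prompt lengths are what the mask is built from.

```python
import torch


def get_chroma_attn_mask(length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
    # Start from an all-False mask: every position is treated as padding by default.
    attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
    for i, n_tokens in enumerate(length.tolist()):
        # Assumption: unmask the real tokens reported by the tokenizer
        # (`return_length=True`); the remaining padding positions stay masked out.
        attention_mask[i, : int(n_tokens)] = True
    return attention_mask


# Example: prompts of 5 and 9 tokens padded to a 16-token window.
mask = get_chroma_attn_mask(torch.tensor([5, 9]), max_sequence_length=16)
print(mask.sum(dim=1))  # tensor([5, 9])
```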
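Taken together, the refactor leaves a single T5 tokenizer/text encoder, drops pooled projections entirely, and gates classifier-free guidance purely on `guidance_scale > 1`. A minimal end-to-end sketch, assuming the single-file checkpoint from the updated docstring example is available locally and that a `negative_prompt` is supplied whenever `guidance_scale > 1` (the negative branch is encoded unconditionally in that case):

```python
import torch

from diffusers import ChromaPipeline

# Assumes the checkpoint file referenced by the docstring example exists locally.
pipe = ChromaPipeline.from_single_file(
    "chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says hello world",
    negative_prompt="low quality, blurry",  # encoded because guidance_scale > 1 enables do_cfg
    guidance_scale=4.0,
    num_inference_steps=28,
).images[0]
image.save("chroma.png")
```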