diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 5041e352f7..4e3d66f00a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -31,16 +31,13 @@ from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMi from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, - deprecate, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxMixin from .pipeline_output import FluxPipelineOutput @@ -146,6 +143,7 @@ def retrieve_timesteps( class FluxPipeline( DiffusionPipeline, + FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, @@ -215,178 +213,6 @@ class FluxPipeline( ) self.default_sample_size = 128 - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -503,97 +329,6 @@ class FluxPipeline( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." - deprecate( - "enable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." - deprecate( - "disable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." - deprecate( - "enable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." - deprecate( - "disable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.disable_tiling() - def prepare_latents( self, batch_size, @@ -629,26 +364,6 @@ class FluxPipeline( return latents, latent_image_ids - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def current_timestep(self): - return self._current_timestep - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 848d7bd392..0fcdf159b6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -25,16 +25,13 @@ from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, - deprecate, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxControlMixin from .pipeline_output import FluxPipelineOutput @@ -159,6 +156,7 @@ def retrieve_timesteps( class FluxControlPipeline( DiffusionPipeline, + FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, @@ -226,181 +224,6 @@ class FluxControlPipeline( ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - def check_inputs( self, prompt, @@ -451,100 +274,6 @@ class FluxControlPipeline( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." - deprecate( - "enable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." - deprecate( - "disable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." - deprecate( - "enable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." - deprecate( - "disable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.disable_tiling() - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, @@ -581,57 +310,6 @@ class FluxControlPipeline( return latents, latent_image_ids - # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if isinstance(image, torch.Tensor): - pass - else: - image = self.image_processor.preprocess(image, height=height, width=width) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index 262345c75a..b94bfde27f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -20,20 +20,18 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxControlMixin from .pipeline_output import FluxPipelineOutput @@ -175,7 +173,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin): r""" The Flux pipeline for image inpainting. @@ -236,181 +234,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): @@ -492,47 +315,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - def prepare_latents( self, image, @@ -615,22 +397,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin return image - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index 6915a83a7c..d9c3f9aca1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -34,16 +34,13 @@ from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, - deprecate, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxControlMixin from .pipeline_output import FluxPipelineOutput @@ -204,6 +201,7 @@ def retrieve_timesteps( class FluxControlInpaintPipeline( DiffusionPipeline, + FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, @@ -276,181 +274,6 @@ class FluxControlInpaintPipeline( ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): @@ -532,100 +355,6 @@ class FluxControlInpaintPipeline( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." - deprecate( - "enable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." - deprecate( - "disable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." - deprecate( - "enable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." - deprecate( - "disable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.disable_tiling() - def prepare_latents( self, image, @@ -786,22 +515,6 @@ class FluxControlInpaintPipeline( return mask_image, masked_image_latents - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 507ec68734..991e0aeb96 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -27,21 +27,19 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxControlMixin from .pipeline_output import FluxPipelineOutput @@ -174,7 +172,9 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): +class FluxControlNetPipeline( + DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin +): r""" The Flux pipeline for text-to-image generation. @@ -245,181 +245,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ) self.default_sample_size = 128 - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -546,47 +371,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, @@ -658,22 +442,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF return image - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 582c7bbad8..710f000da2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -11,21 +11,19 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxControlMixin from .pipeline_output import FluxPipelineOutput @@ -169,7 +167,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin): r""" The Flux controlnet pipeline for image-to-image generation. @@ -236,181 +234,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): @@ -492,47 +315,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - def prepare_latents( self, image, @@ -615,22 +397,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From return image - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index f7f34ef231..a3971e7dbd 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -12,21 +12,19 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxControlMixin from .pipeline_output import FluxPipelineOutput @@ -171,7 +169,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin): r""" The Flux controlnet pipeline for inpainting. @@ -247,181 +245,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): @@ -520,47 +343,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - def prepare_latents( self, image, @@ -719,22 +501,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From return image - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 5cb9c82204..2c3a5c9e4a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -25,16 +25,13 @@ from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, - deprecate, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxMixin from .pipeline_output import FluxPipelineOutput @@ -167,6 +164,7 @@ def retrieve_latents( class FluxFillPipeline( DiffusionPipeline, + FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, @@ -241,101 +239,6 @@ class FluxFillPipeline( ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - def prepare_mask_latents( self, mask, @@ -416,86 +319,6 @@ class FluxFillPipeline( return mask, masked_image_latents - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): @@ -588,100 +411,6 @@ class FluxFillPipeline( if image is not None and mask_image is None: raise ValueError("Please provide `mask_image` when passing `image`.") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." - deprecate( - "enable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." - deprecate( - "disable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." - deprecate( - "enable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." - deprecate( - "disable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.disable_tiling() - # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents def prepare_latents( self, @@ -733,22 +462,6 @@ class FluxFillPipeline( latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index ab9140dae9..e9ca3e1a44 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -27,21 +27,19 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, deprecate, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxMixin from .pipeline_output import FluxPipelineOutput @@ -167,7 +165,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): +class FluxImg2ImgPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): r""" The Flux pipeline for image inpainting. @@ -235,181 +233,6 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -567,47 +390,6 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" @@ -715,22 +497,6 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 3bfe82cf43..6a844a5051 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -28,20 +28,18 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxMixin from .pipeline_output import FluxPipelineOutput @@ -163,7 +161,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterMixin): +class FluxInpaintPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FluxIPAdapterMixin): r""" The Flux pipeline for image inpainting. @@ -238,181 +236,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -587,47 +410,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - def prepare_latents( self, image, @@ -756,22 +538,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM return mask, masked_image_latents - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index 94ae460afc..10877bd1ef 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -31,16 +31,14 @@ from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMi from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, deprecate, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxMixin from .pipeline_output import FluxPipelineOutput @@ -190,6 +188,7 @@ def retrieve_latents( class FluxKontextPipeline( DiffusionPipeline, + FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, @@ -260,181 +259,6 @@ class FluxKontextPipeline( ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -554,47 +378,6 @@ class FluxKontextPipeline( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ @@ -728,22 +511,6 @@ class FluxKontextPipeline( return latents, image_latents, latent_ids, image_ids - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def current_timestep(self): - return self._current_timestep - @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py index b6f957981e..7101dcc9af 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py @@ -21,16 +21,14 @@ from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMi from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, deprecate, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pipeline_flux_utils import FluxMixin from .pipeline_output import FluxPipelineOutput @@ -214,6 +212,7 @@ def retrieve_latents( class FluxKontextInpaintPipeline( DiffusionPipeline, + FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, @@ -293,181 +292,6 @@ class FluxKontextInpaintPipeline( ) self.default_sample_size = 128 - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -628,47 +452,6 @@ class FluxKontextInpaintPipeline( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ @@ -917,22 +700,6 @@ class FluxKontextInpaintPipeline( return mask, masked_image_latents - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def current_timestep(self): - return self._current_timestep - @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index e79db337b2..5a00edb6b1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -27,17 +27,14 @@ from transformers import ( ) from ...image_processor import PipelineImageInput -from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...utils import ( - USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ..pipeline_utils import DiffusionPipeline from .modeling_flux import ReduxImageEncoder +from .pipeline_flux_utils import FluxMixin from .pipeline_output import FluxPriorReduxPipelineOutput @@ -81,7 +78,7 @@ EXAMPLE_DOC_STRING = """ """ -class FluxPriorReduxPipeline(DiffusionPipeline): +class FluxPriorReduxPipeline(DiffusionPipeline, FluxMixin): r""" The Flux Redux pipeline for image-to-image generation. @@ -193,181 +190,6 @@ class FluxPriorReduxPipeline(DiffusionPipeline): return image_enc_hidden_states - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Optional[Union[str, List[str]]] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__(