From 73b59f5203b5df71175dfd71f613b9bd380b4531 Mon Sep 17 00:00:00 2001 From: Ina <1224084650@qq.com> Date: Sat, 26 Oct 2024 05:01:51 +0800 Subject: [PATCH] [refactor] enhance readability of flux related pipelines (#9711) * flux pipline: readability enhancement. --- .../train_dreambooth_lora_flux_advanced.py | 8 ++--- examples/controlnet/train_controlnet_flux.py | 4 +-- examples/dreambooth/train_dreambooth_flux.py | 10 +++--- .../dreambooth/train_dreambooth_lora_flux.py | 10 +++--- src/diffusers/pipelines/flux/pipeline_flux.py | 26 +++++++------- .../flux/pipeline_flux_controlnet.py | 26 +++++++------- ...pipeline_flux_controlnet_image_to_image.py | 28 ++++++++------- .../pipeline_flux_controlnet_inpainting.py | 34 +++++++++++-------- .../pipelines/flux/pipeline_flux_img2img.py | 28 ++++++++------- .../pipelines/flux/pipeline_flux_inpaint.py | 32 +++++++++-------- 10 files changed, 110 insertions(+), 96 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index e3e46ead8e..ccc390ab7b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2198,8 +2198,8 @@ def main(args): latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], - model_input.shape[2], - model_input.shape[3], + model_input.shape[2] // 2, + model_input.shape[3] // 2, accelerator.device, weight_dtype, ) @@ -2253,8 +2253,8 @@ def main(args): )[0] model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2] * vae_scale_factor / 2), - width=int(model_input.shape[3] * vae_scale_factor / 2), + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, vae_scale_factor=vae_scale_factor, ) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index ca822b16ea..2958a9e5f2 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -1256,8 +1256,8 @@ def main(args): latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids( batch_size=pixel_latents_tmp.shape[0], - height=pixel_latents_tmp.shape[2], - width=pixel_latents_tmp.shape[3], + height=pixel_latents_tmp.shape[2] // 2, + width=pixel_latents_tmp.shape[3] // 2, device=pixel_values.device, dtype=pixel_values.dtype, ) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index db4788281c..add266d3ac 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1540,12 +1540,12 @@ def main(args): model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], - model_input.shape[2], - model_input.shape[3], + model_input.shape[2] // 2, + model_input.shape[3] // 2, accelerator.device, weight_dtype, ) @@ -1601,8 +1601,8 @@ def main(args): # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042 model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2] * vae_scale_factor / 2), - width=int(model_input.shape[3] * vae_scale_factor / 2), + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, vae_scale_factor=vae_scale_factor, ) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index b09e5b38b2..fa4db10f4f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1645,12 +1645,12 @@ def main(args): model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], - model_input.shape[2], - model_input.shape[3], + model_input.shape[2] // 2, + model_input.shape[3] // 2, accelerator.device, weight_dtype, ) @@ -1704,8 +1704,8 @@ def main(args): )[0] model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2] * vae_scale_factor / 2), - width=int(model_input.shape[3] * vae_scale_factor / 2), + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, vae_scale_factor=vae_scale_factor, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 8278365e94..040d935f1b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -195,13 +195,13 @@ class FluxPipeline( scheduler=scheduler, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 def _get_t5_prompt_embeds( self, @@ -386,8 +386,10 @@ class FluxPipeline( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -425,9 +427,9 @@ class FluxPipeline( @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -452,10 +454,10 @@ class FluxPipeline( height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -499,8 +501,8 @@ class FluxPipeline( generator, latents=None, ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) @@ -517,7 +519,7 @@ class FluxPipeline( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents, latent_image_ids diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 5136c42001..9f33e26013 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -216,13 +216,13 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF controlnet=controlnet, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 def _get_t5_prompt_embeds( self, @@ -410,8 +410,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -450,9 +452,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -479,10 +481,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -498,8 +500,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF generator, latents=None, ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) @@ -516,7 +518,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents, latent_image_ids diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 8d636feeae..810c970ab7 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -228,13 +228,13 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From controlnet=controlnet, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -453,8 +453,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -493,9 +495,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -522,10 +524,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -549,11 +551,11 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) if latents is not None: return latents.to(device=device, dtype=dtype), latent_image_ids @@ -852,7 +854,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From control_mode = control_mode.reshape([-1, 1]) sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 46784f2d46..3ca2de633f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -231,7 +231,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( @@ -244,7 +244,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -467,8 +467,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -520,9 +522,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -549,10 +551,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -576,11 +578,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) @@ -622,8 +624,8 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From device, generator, ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision @@ -996,7 +998,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From # 6. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor) + image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( + int(global_width) // self.vae_scale_factor // 2 + ) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 112260003e..47f9f268ee 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -212,13 +212,13 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): scheduler=scheduler, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -437,8 +437,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -477,9 +479,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -506,10 +508,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -532,11 +534,11 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) if latents is not None: return latents.to(device=device, dtype=dtype), latent_image_ids @@ -736,7 +738,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): # 4.Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index ae348c0f64..766f986483 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -209,7 +209,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): scheduler=scheduler, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( @@ -222,7 +222,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -445,8 +445,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -498,9 +500,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -527,10 +529,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -553,11 +555,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) @@ -598,8 +600,8 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): device, generator, ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision @@ -866,7 +868,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): # 4.Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len,