From 36159dd2a6971ef72416ea87a299359f82df1c96 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 29 May 2025 04:34:31 +0200
Subject: [PATCH] add tests

---
 .../pipelines/flux/pipeline_flux_kontext.py |  16 +-
 .../flux/test_pipeline_flux_kontext.py      | 177 ++++++++++++++++++
 2 files changed, 187 insertions(+), 6 deletions(-)
 create mode 100644 tests/pipelines/flux/test_pipeline_flux_kontext.py

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
index 65102dea2d..bc8d05cd68 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
@@ -740,6 +740,7 @@ class FluxKontextPipeline(
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
         max_area: int = 1024**2,
+        _auto_resize: bool = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -937,13 +938,16 @@ class FluxKontextPipeline(
 
         # 3. Preprocess image
         if not torch.is_tensor(image) or image.size(1) == self.latent_channels:
-            image_width, image_height = self.image_processor.get_default_height_width(image)
+            if isinstance(image, list):
+                image_width, image_height = self.image_processor.get_default_height_width(image[0])
+            else:
+                image_width, image_height = self.image_processor.get_default_height_width(image)
             aspect_ratio = image_width / image_height
-
-            # Kontext is trained on specific resolutions, using one of them is recommended
-            _, image_width, image_height = min(
-                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
-            )
+            if _auto_resize:
+                # Kontext is trained on specific resolutions, using one of them is recommended
+                _, image_width, image_height = min(
+                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+                )
             image_width = image_width // multiple_of * multiple_of
             image_height = image_height // multiple_of * multiple_of
             image = self.image_processor.resize(image, image_height, image_width)
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext.py b/tests/pipelines/flux/test_pipeline_flux_kontext.py
new file mode 100644
index 0000000000..7471d78ad5
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext.py
@@ -0,0 +1,177 @@
+import unittest
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+    AutoencoderKL,
+    FasterCacheConfig,
+    FlowMatchEulerDiscreteScheduler,
+    FluxKontextPipeline,
+    FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import torch_device
+
+from ..test_pipelines_common import (
+    FasterCacheTesterMixin,
+    FluxIPAdapterTesterMixin,
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+)
+
+
+class FluxKontextPipelineFastTests(
+    unittest.TestCase,
+    PipelineTesterMixin,
+    FluxIPAdapterTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    FasterCacheTesterMixin,
+):
+    pipeline_class = FluxKontextPipeline
+    params = frozenset(
+        ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+    )
+    batch_params = frozenset(["image", "prompt"])
+
+    # there is no xformers processor for Flux
+    test_xformers_attention = False
+    test_layerwise_casting = True
+    test_group_offloading = True
+
+    faster_cache_config = FasterCacheConfig(
+        spatial_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(-1, 901),
+        unconditional_batch_skip_range=2,
+        attention_weight_callback=lambda _: 0.5,
+        is_guidance_distilled=True,
+    )
+
+    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+        torch.manual_seed(0)
+        transformer = FluxTransformer2DModel(
+            patch_size=1,
+            in_channels=4,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
+            attention_head_dim=16,
+            num_attention_heads=2,
+            joint_attention_dim=32,
+            pooled_projection_dim=32,
+            axes_dims_rope=[4, 4, 8],
+        )
+        clip_text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+
+        torch.manual_seed(0)
+        text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+        torch.manual_seed(0)
+        text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            block_out_channels=(4,),
+            layers_per_block=1,
+            latent_channels=1,
+            norm_num_groups=1,
+            use_quant_conv=False,
+            use_post_quant_conv=False,
+            shift_factor=0.0609,
+            scaling_factor=1.5035,
+        )
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "transformer": transformer,
+            "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
+        }
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device="cpu").manual_seed(seed)
+
+        image = PIL.Image.new("RGB", (32, 32), 0)
+        inputs = {
+            "image": image,
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "height": 8,
+            "width": 8,
+            "max_area": 8 * 8,
+            "max_sequence_length": 48,
+            "output_type": "np",
+            "_auto_resize": False,
+        }
+        return inputs
+
+    def test_flux_different_prompts(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_same_prompt = pipe(**inputs).images[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt_2"] = "a different prompt"
+        output_different_prompts = pipe(**inputs).images[0]
+
+        max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+        # Outputs should be different here
+        # For some reason, they don't show large differences
+        assert max_diff > 1e-6
+
+    def test_flux_image_output_shape(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        height_width_pairs = [(32, 32), (72, 57)]
+        for height, width in height_width_pairs:
+            expected_height = height - height % (pipe.vae_scale_factor * 2)
+            expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+            inputs.update({"height": height, "width": width, "max_area": height * width})
+            image = pipe(**inputs).images[0]
+            output_height, output_width, _ = image.shape
+            assert (output_height, output_width) == (expected_height, expected_width)
+
+    def test_flux_true_cfg(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs.pop("generator")
+
+        no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+        inputs["negative_prompt"] = "bad quality"
+        inputs["true_cfg_scale"] = 2.0
+        true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+        assert not np.allclose(no_true_cfg_out, true_cfg_out)
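
For context on what the new flag toggles, a minimal usage sketch (not part of the patch; the checkpoint id and image path below are placeholders, and `_auto_resize` is a private keyword argument of `FluxKontextPipeline.__call__` as added above):

import torch

from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Placeholder checkpoint id/path for a Flux Kontext model.
pipe = FluxKontextPipeline.from_pretrained("<kontext-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")
image = load_image("<reference-image.png>")

# Default behaviour: the reference image is resized to the closest entry in
# PREFERRED_KONTEXT_RESOLUTIONS before being encoded.
auto = pipe(image=image, prompt="make the sky stormy").images[0]

# With _auto_resize=False the image keeps its own size (only rounded down to the
# multiple that latent packing requires), which is what the fast tests rely on.
fixed = pipe(image=image, prompt="make the sky stormy", _auto_resize=False).images[0]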