From 565d674cc42dcf4b51b48886e85b615e93ef9b4b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 16:30:58 +0530 Subject: [PATCH] change flux lora integration tests to use pytest --- tests/lora/test_lora_layers_flux.py | 199 +++++++++++++++++++++------- tests/testing_utils.py | 2 +- 2 files changed, 152 insertions(+), 49 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 8b9e6ec472..b7518d701a 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -16,7 +16,6 @@ import copy import gc import os import sys -import unittest import numpy as np import pytest @@ -26,7 +25,7 @@ from PIL import Image from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel -from diffusers.utils import logging +from diffusers.utils import load_image, logging from ..testing_utils import ( CaptureLogger, @@ -752,7 +751,7 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests): @require_torch_accelerator @require_peft_backend @require_big_accelerator -class FluxLoRAIntegrationTests(unittest.TestCase): +class TestFluxLoRAIntegration: """internal note: The integration slices were obtained on audace. torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the @@ -762,25 +761,25 @@ class FluxLoRAIntegrationTests(unittest.TestCase): num_inference_steps = 10 seed = 0 - def setUp(self): - super().setUp() + @pytest.fixture(scope="function") + def pipeline(self, torch_device): gc.collect() backend_empty_cache(torch_device) - self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + try: + yield pipe + finally: + del pipe + gc.collect() + backend_empty_cache(torch_device) - def tearDown(self): - super().tearDown() - del self.pipeline - gc.collect() - backend_empty_cache(torch_device) - - def test_flux_the_last_ben(self): - self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) + def test_flux_the_last_ben(self, pipeline): + pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "jon snow eating pizza with ketchup" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.0, @@ -792,13 +791,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 - def test_flux_kohya(self): - self.pipeline.load_lora_weights("Norod78/brain-slug-flux") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) + def test_flux_kohya(self, pipeline): + pipeline.load_lora_weights("Norod78/brain-slug-flux") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "The cat with a brain slug earring" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.5, @@ -810,13 +809,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 - def test_flux_kohya_with_text_encoder(self): - self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) + def test_flux_kohya_with_text_encoder(self, pipeline): + pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "optimus is cleaning the house with broomstick" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.5, @@ -828,19 +827,18 @@ class FluxLoRAIntegrationTests(unittest.TestCase): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 - def test_flux_kohya_embedders_conversion(self): + def test_flux_kohya_embedders_conversion(self, pipeline): """Test that embedders load without throwing errors""" - self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") - self.pipeline.unload_lora_weights() - assert True + pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") + pipeline.unload_lora_weights() - def test_flux_xlabs(self): - self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) + def test_flux_xlabs(self, pipeline): + pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "A blue jay standing on a large basket of rainbow macarons, disney style" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=3.5, @@ -852,15 +850,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 - def test_flux_xlabs_load_lora_with_single_blocks(self): - self.pipeline.load_lora_weights( - "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors" - ) - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline.enable_model_cpu_offload() + def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline): + pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline.enable_model_cpu_offload() prompt = "a wizard mouse playing chess" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=3.5, @@ -873,3 +869,110 @@ class FluxLoRAIntegrationTests(unittest.TestCase): ) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 0.001 + + +@nightly +@require_torch_accelerator +@require_peft_backend +@require_big_accelerator +class TestFluxControlLoRAIntegration: + num_inference_steps = 10 + seed = 0 + prompt = "A robot made of exotic candies and chocolates of different kinds." + + @pytest.fixture(scope="function") + def pipeline(self, torch_device): + gc.collect() + backend_empty_cache(torch_device) + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + try: + yield pipe + finally: + del pipe + gc.collect() + backend_empty_cache(torch_device) + + @pytest.mark.parametrize( + "lora_ckpt_id", + [ + "black-forest-labs/FLUX.1-Canny-dev-lora", + "black-forest-labs/FLUX.1-Depth-dev-lora", + ], + ) + def test_lora(self, pipeline, lora_ckpt_id): + pipeline.load_lora_weights(lora_ckpt_id) + pipeline.fuse_lora() + pipeline.unload_lora_weights() + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = pipeline( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=self.num_inference_steps, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516]) + else: + expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 + + @pytest.mark.parametrize( + "lora_ckpt_id", + [ + "black-forest-labs/FLUX.1-Canny-dev-lora", + "black-forest-labs/FLUX.1-Depth-dev-lora", + ], + ) + def test_lora_with_turbo(self, pipeline, lora_ckpt_id): + pipeline.load_lora_weights(lora_ckpt_id) + pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + + if "Canny" in lora_ckpt_id: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png" + ) + else: + control_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" + ) + + image = self.pipeline( + prompt=self.prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=self.num_inference_steps, + guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = image[0, -3:, -3:, -1].flatten() + if "Canny" in lora_ckpt_id: + expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484]) + else: + expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562]) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 73d6045915..988834acf5 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -266,7 +266,7 @@ def slow(test_case): Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. """ - return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case) def nightly(test_case):