1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

change flux lora integration tests to use pytest

This commit is contained in:
sayakpaul
2025-10-03 16:30:58 +05:30
parent 610842af1a
commit 565d674cc4
2 changed files with 152 additions and 49 deletions

View File

@@ -16,7 +16,6 @@ import copy
import gc
import os
import sys
import unittest
import numpy as np
import pytest
@@ -26,7 +25,7 @@ from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
from diffusers.utils import logging
from diffusers.utils import load_image, logging
from ..testing_utils import (
CaptureLogger,
@@ -752,7 +751,7 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
class TestFluxLoRAIntegration:
"""internal note: The integration slices were obtained on audace.
torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
@@ -762,25 +761,25 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
def setUp(self):
super().setUp()
@pytest.fixture(scope="function")
def pipeline(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
del self.pipeline
gc.collect()
backend_empty_cache(torch_device)
def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_the_last_ben(self, pipeline):
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "jon snow eating pizza with ketchup"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.0,
@@ -792,13 +791,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
def test_flux_kohya(self):
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_kohya(self, pipeline):
pipeline.load_lora_weights("Norod78/brain-slug-flux")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "The cat with a brain slug earring"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
@@ -810,13 +809,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
def test_flux_kohya_with_text_encoder(self):
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_kohya_with_text_encoder(self, pipeline):
pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "optimus is cleaning the house with broomstick"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
@@ -828,19 +827,18 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
def test_flux_kohya_embedders_conversion(self):
def test_flux_kohya_embedders_conversion(self, pipeline):
"""Test that embedders load without throwing errors"""
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
self.pipeline.unload_lora_weights()
assert True
pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
pipeline.unload_lora_weights()
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_xlabs(self, pipeline):
pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
@@ -852,15 +850,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
def test_flux_xlabs_load_lora_with_single_blocks(self):
self.pipeline.load_lora_weights(
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()
def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline):
pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.enable_model_cpu_offload()
prompt = "a wizard mouse playing chess"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
@@ -873,3 +869,110 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
@nightly
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class TestFluxControlLoRAIntegration:
num_inference_steps = 10
seed = 0
prompt = "A robot made of exotic candies and chocolates of different kinds."
@pytest.fixture(scope="function")
def pipeline(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
@pytest.mark.parametrize(
"lora_ckpt_id",
[
"black-forest-labs/FLUX.1-Canny-dev-lora",
"black-forest-labs/FLUX.1-Depth-dev-lora",
],
)
def test_lora(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.fuse_lora()
pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id:
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
)
else:
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
)
image = pipeline(
prompt=self.prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=self.num_inference_steps,
guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = image[0, -3:, -3:, -1].flatten()
if "Canny" in lora_ckpt_id:
expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516])
else:
expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
@pytest.mark.parametrize(
"lora_ckpt_id",
[
"black-forest-labs/FLUX.1-Canny-dev-lora",
"black-forest-labs/FLUX.1-Depth-dev-lora",
],
)
def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id:
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
)
else:
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
)
image = self.pipeline(
prompt=self.prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=self.num_inference_steps,
guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = image[0, -3:, -3:, -1].flatten()
if "Canny" in lora_ckpt_id:
expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484])
else:
expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3

View File

@@ -266,7 +266,7 @@ def slow(test_case):
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)
def nightly(test_case):