change flux lora integration tests to use pytest
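The conversion follows one pattern throughout: unittest's `setUp`/`tearDown` pair becomes a function-scoped pytest fixture whose `yield` hands the resource to the test and whose `finally` block guarantees cleanup. A minimal, self-contained sketch of that pattern (the names here are illustrative, not taken from the diff):

import gc

import pytest


class TestExample:
    @pytest.fixture(scope="function")
    def resource(self):
        # setup: build the expensive object (a pipeline, in the real tests)
        res = {"ready": True}
        try:
            yield res  # the test body runs at this point
        finally:
            # teardown runs even if the test raises, replacing tearDown()
            del res
            gc.collect()

    def test_uses_resource(self, resource):
        # pytest injects the fixture by matching the parameter name
        assert resource["ready"]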
@@ -16,7 +16,6 @@ import copy
 import gc
 import os
 import sys
-import unittest

 import numpy as np
 import pytest
@@ -26,7 +25,7 @@ from PIL import Image
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils import logging
+from diffusers.utils import load_image, logging

 from ..testing_utils import (
     CaptureLogger,
@@ -752,7 +751,7 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
 @require_torch_accelerator
 @require_peft_backend
 @require_big_accelerator
-class FluxLoRAIntegrationTests(unittest.TestCase):
+class TestFluxLoRAIntegration:
     """internal note: The integration slices were obtained on audace.

     torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
@@ -762,25 +761,25 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
     num_inference_steps = 10
     seed = 0

-    def setUp(self):
-        super().setUp()
+    @pytest.fixture(scope="function")
+    def pipeline(self, torch_device):
         gc.collect()
         backend_empty_cache(torch_device)
-        self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
-
-    def tearDown(self):
-        super().tearDown()
-        del self.pipeline
-        gc.collect()
-        backend_empty_cache(torch_device)
+        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+        try:
+            yield pipe
+        finally:
+            del pipe
+            gc.collect()
+            backend_empty_cache(torch_device)

-    def test_flux_the_last_ben(self):
-        self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
-        self.pipeline.fuse_lora()
-        self.pipeline.unload_lora_weights()
-        self.pipeline = self.pipeline.to(torch_device)
+    def test_flux_the_last_ben(self, pipeline):
+        pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
+        pipeline = pipeline.to(torch_device)
         prompt = "jon snow eating pizza with ketchup"
-        out = self.pipeline(
+        out = pipeline(
             prompt,
             num_inference_steps=self.num_inference_steps,
             guidance_scale=4.0,
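The new `pipeline` fixture declares a `torch_device` parameter, which pytest must resolve as another fixture rather than the module-level import the old `setUp` relied on. A hedged, conftest-level sketch of how such a fixture could be provided (an assumption for illustration, not the repository's actual conftest.py):

# conftest.py (sketch, assumed): expose a torch_device fixture that test
# and fixture functions can request by name.
import pytest
import torch


@pytest.fixture(scope="session")
def torch_device():
    # pick the best available backend; the real project derives this value
    # from its shared testing utilities
    return "cuda" if torch.cuda.is_available() else "cpu"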
@@ -792,13 +791,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
         assert max_diff < 0.001

-    def test_flux_kohya(self):
-        self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
-        self.pipeline.fuse_lora()
-        self.pipeline.unload_lora_weights()
-        self.pipeline = self.pipeline.to(torch_device)
+    def test_flux_kohya(self, pipeline):
+        pipeline.load_lora_weights("Norod78/brain-slug-flux")
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
+        pipeline = pipeline.to(torch_device)
         prompt = "The cat with a brain slug earring"
-        out = self.pipeline(
+        out = pipeline(
             prompt,
             num_inference_steps=self.num_inference_steps,
             guidance_scale=4.5,
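Each test asserts that `numpy_cosine_similarity_distance` between the expected and observed image slices stays below 1e-3. That helper lives in the repo's testing utilities; a sketch of what such a distance computes (assumed implementation, shown only to make the assertion concrete):

import numpy as np


def numpy_cosine_similarity_distance(a, b):
    # 1 - cosine similarity: 0.0 when the flattened slices point the same
    # way, growing toward 2.0 as they diverge
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)


x = np.array([0.84, 0.83, 0.85])
assert numpy_cosine_similarity_distance(x, x) < 1e-6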
@@ -810,13 +809,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
         assert max_diff < 0.001

-    def test_flux_kohya_with_text_encoder(self):
-        self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
-        self.pipeline.fuse_lora()
-        self.pipeline.unload_lora_weights()
-        self.pipeline = self.pipeline.to(torch_device)
+    def test_flux_kohya_with_text_encoder(self, pipeline):
+        pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
+        pipeline = pipeline.to(torch_device)
         prompt = "optimus is cleaning the house with broomstick"
-        out = self.pipeline(
+        out = pipeline(
             prompt,
             num_inference_steps=self.num_inference_steps,
             guidance_scale=4.5,
@@ -828,19 +827,18 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
         assert max_diff < 0.001

-    def test_flux_kohya_embedders_conversion(self):
+    def test_flux_kohya_embedders_conversion(self, pipeline):
         """Test that embedders load without throwing errors"""
-        self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
-        self.pipeline.unload_lora_weights()
-        assert True
+        pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
+        pipeline.unload_lora_weights()

-    def test_flux_xlabs(self):
-        self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
-        self.pipeline.fuse_lora()
-        self.pipeline.unload_lora_weights()
-        self.pipeline = self.pipeline.to(torch_device)
+    def test_flux_xlabs(self, pipeline):
+        pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
+        pipeline = pipeline.to(torch_device)
         prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
-        out = self.pipeline(
+        out = pipeline(
             prompt,
             num_inference_steps=self.num_inference_steps,
             guidance_scale=3.5,
@@ -852,15 +850,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
         assert max_diff < 0.001

-    def test_flux_xlabs_load_lora_with_single_blocks(self):
-        self.pipeline.load_lora_weights(
-            "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
-        )
-        self.pipeline.fuse_lora()
-        self.pipeline.unload_lora_weights()
-        self.pipeline.enable_model_cpu_offload()
+    def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline):
+        pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors")
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
+        pipeline.enable_model_cpu_offload()
         prompt = "a wizard mouse playing chess"
-        out = self.pipeline(
+        out = pipeline(
             prompt,
             num_inference_steps=self.num_inference_steps,
             guidance_scale=3.5,
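With `unittest.TestCase` gone, these are plain pytest items, selectable by node ID or keyword from the CLI or programmatically; the file path below is an assumption for illustration, not taken from the diff:

import pytest

# run one converted test by node ID (path assumed)
pytest.main(["tests/lora/test_lora_layers_flux.py::TestFluxLoRAIntegration::test_flux_xlabs"])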
@@ -873,3 +869,110 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
         )
         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
         assert max_diff < 0.001
+
+
+@nightly
+@require_torch_accelerator
+@require_peft_backend
+@require_big_accelerator
+class TestFluxControlLoRAIntegration:
+    num_inference_steps = 10
+    seed = 0
+    prompt = "A robot made of exotic candies and chocolates of different kinds."
+
+    @pytest.fixture(scope="function")
+    def pipeline(self, torch_device):
+        gc.collect()
+        backend_empty_cache(torch_device)
+        pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+        try:
+            yield pipe
+        finally:
+            del pipe
+            gc.collect()
+            backend_empty_cache(torch_device)
+
+    @pytest.mark.parametrize(
+        "lora_ckpt_id",
+        [
+            "black-forest-labs/FLUX.1-Canny-dev-lora",
+            "black-forest-labs/FLUX.1-Depth-dev-lora",
+        ],
+    )
+    def test_lora(self, pipeline, lora_ckpt_id):
+        pipeline.load_lora_weights(lora_ckpt_id)
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
+
+        if "Canny" in lora_ckpt_id:
+            control_image = load_image(
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
+            )
+        else:
+            control_image = load_image(
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
+            )
+
+        image = pipeline(
+            prompt=self.prompt,
+            control_image=control_image,
+            height=1024,
+            width=1024,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
+
+        out_slice = image[0, -3:, -3:, -1].flatten()
+        if "Canny" in lora_ckpt_id:
+            expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516])
+        else:
+            expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
+
+    @pytest.mark.parametrize(
+        "lora_ckpt_id",
+        [
+            "black-forest-labs/FLUX.1-Canny-dev-lora",
+            "black-forest-labs/FLUX.1-Depth-dev-lora",
+        ],
+    )
+    def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
+        pipeline.load_lora_weights(lora_ckpt_id)
+        pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
+
+        if "Canny" in lora_ckpt_id:
+            control_image = load_image(
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
+            )
+        else:
+            control_image = load_image(
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
+            )
+
+        image = pipeline(
+            prompt=self.prompt,
+            control_image=control_image,
+            height=1024,
+            width=1024,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
+
+        out_slice = image[0, -3:, -3:, -1].flatten()
+        if "Canny" in lora_ckpt_id:
+            expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484])
+        else:
+            expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
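`@pytest.mark.parametrize` is what the unittest version could not express directly: each checkpoint ID becomes its own collected test that passes or fails independently. A self-contained illustration of the mechanism:

import pytest


@pytest.mark.parametrize(
    "lora_ckpt_id",
    [
        "black-forest-labs/FLUX.1-Canny-dev-lora",
        "black-forest-labs/FLUX.1-Depth-dev-lora",
    ],
)
def test_ids(lora_ckpt_id):
    # pytest collects one item per parameter and embeds the string in the
    # test ID, so `pytest -k Canny` selects only the Canny variant
    assert "FLUX.1" in lora_ckpt_id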
@@ -266,7 +266,7 @@ def slow(test_case):
     Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.

     """
-    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+    return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)


 def nightly(test_case):
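The rewritten `slow` builds the skip out of `pytest.mark.skipif`, with the condition evaluated once at decoration time; `_run_slow_tests` presumably reflects the RUN_SLOW environment variable named in the docstring. A standalone sketch of that arrangement (the parsing of RUN_SLOW is an assumption, not copied from the file):

import os

import pytest

# assumed: a truthy RUN_SLOW opts in to slow tests, per the docstring above
_run_slow_tests = os.environ.get("RUN_SLOW", "").lower() in ("1", "true", "yes")


def slow(test_case):
    # skipif is applied when the module is imported, not when the test runs
    return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)


@slow
def test_expensive_path():
    assert sum(range(10_000)) == 49_995_000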