1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

make pipelines tests device-agnostic (part2) (#9400)

* enable on xpu

* add 1 more

* add one more

* enable more

* add 1 more

* add more

* enable 1

* enable more cases

* enable

* enable

* update comment

* one more

* enable 1

* add more cases

* enable xpu

* add one more caswe

* add more cases

* add 1

* add more

* add more cases

* add case

* enable

* add more

* add more

* add more

* enbale more

* add more

* update code

* update test marker

* add skip back

* update comment

* remove single files

* remove

* style

* add

* revert

* reformat

* enable

* enable esingle g

* add 2 more

* update decorator

* update

* update

* update

* Update tests/pipelines/deepfloyd_if/test_if.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update tests/pipelines/animatediff/test_animatediff_controlnet.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update tests/pipelines/animatediff/test_animatediff.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update tests/pipelines/animatediff/test_animatediff_controlnet.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* update float16

* no unitest.skipt

* update

* apply style check

* adapt style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
Fanli Lin
2024-11-29 14:03:41 +08:00
committed by GitHub
parent fdec8bd675
commit 6b288ec44d
16 changed files with 98 additions and 71 deletions

View File

@@ -156,14 +156,14 @@ class SDSingleFileTesterMixin:
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None)
sf_pipe.unet.set_attn_processor(AttnProcessor())
sf_pipe.enable_model_cpu_offload()
sf_pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
image_single_file = sf_pipe(**inputs).images[0]
pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]

View File

@@ -22,9 +22,11 @@ from diffusers import (
ControlNetModel,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -32,7 +34,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class ControlNetModelSingleFileTests(unittest.TestCase):
model_class = ControlNetModel
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
@@ -41,12 +43,12 @@ class ControlNetModelSingleFileTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id)

View File

@@ -21,9 +21,11 @@ import torch
from diffusers import StableCascadeUNet
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -33,17 +35,17 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableCascadeUNetSingleFileTest(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_single_file_components_stage_b(self):
model_single_file = StableCascadeUNet.from_single_file(

View File

@@ -22,10 +22,11 @@ from diffusers import (
AutoencoderKL,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -35,7 +36,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class AutoencoderKLSingleFileTests(unittest.TestCase):
model_class = AutoencoderKL
ckpt_path = (
@@ -48,12 +49,12 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

View File

@@ -8,9 +8,10 @@ from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -27,7 +28,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
@@ -41,12 +42,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -8,10 +8,12 @@ from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import (
@@ -26,7 +28,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
@@ -36,12 +38,12 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self):
control_image = load_image(

View File

@@ -8,10 +8,12 @@ from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import (
@@ -26,7 +28,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
@@ -40,12 +42,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self):
control_image = load_image(

View File

@@ -8,9 +8,11 @@ from diffusers import (
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -20,7 +22,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = (
@@ -34,12 +36,12 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -63,7 +65,7 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
@@ -73,12 +75,12 @@ class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDS
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -8,9 +8,11 @@ from diffusers import (
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -20,7 +22,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
@@ -30,12 +32,12 @@ class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSin
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -78,7 +80,7 @@ class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSin
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = (
@@ -90,12 +92,12 @@ class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDS
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -7,9 +7,11 @@ import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import (
@@ -23,7 +25,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = (
@@ -37,12 +39,12 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -95,12 +97,12 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -8,10 +8,12 @@ from diffusers import (
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -21,7 +23,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionUpscalePipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
@@ -31,12 +33,12 @@ class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSin
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_single_file_format_inference_is_same_as_pretrained(self):
image = load_image(

View File

@@ -11,10 +11,12 @@ from diffusers import (
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import (
@@ -29,7 +31,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLAdapterPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
@@ -41,12 +43,12 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self):
prompt = "toy"

View File

@@ -8,9 +8,10 @@ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -26,7 +27,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLControlNetPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
@@ -38,12 +39,12 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -9,10 +9,12 @@ from diffusers import (
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import SDXLSingleFileTesterMixin
@@ -22,7 +24,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
@@ -34,12 +36,12 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -63,7 +65,7 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests(unittest.TestCase):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = (

View File

@@ -5,9 +5,11 @@ import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -15,7 +17,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
@@ -25,12 +27,12 @@ class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -7,9 +7,11 @@ from diffusers import (
StableDiffusionXLPipeline,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .single_file_testing_utils import SDXLSingleFileTesterMixin
@@ -19,7 +21,7 @@ enable_full_determinism()
@slow
@require_torch_gpu
@require_torch_accelerator
class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
@@ -31,12 +33,12 @@ class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingle
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)