From 3fca52022fe0ea9aaf0a0ea8a0fc13308bf69a9f Mon Sep 17 00:00:00 2001
From: Dong <13023695951@163.com>
Date: Mon, 24 Jun 2024 19:19:55 +0800
Subject: [PATCH] :art: fix xl playground device (#8550)

* :art: fix xl playground device

* :art: run `make fix-copies`

* :art: run `make fix-copies`

* edit xl_controlnet_img2img file

* edit playground img2img test slow

* Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

---------

Co-authored-by: Sayak Paul
---
 .../pipeline_controlnet_sd_xl_img2img.py      |  4 +-
 .../pipeline_stable_diffusion_xl_img2img.py   |  4 +-
 .../test_stable_diffusion_xl_img2img.py       | 55 +++++++++++++++++++
 3 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index fedcfdb420..9d0d784e95 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -949,8 +949,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
 
             init_latents = init_latents.to(dtype)
             if latents_mean is not None and latents_std is not None:
-                latents_mean = latents_mean.to(device=self.device, dtype=dtype)
-                latents_std = latents_std.to(device=self.device, dtype=dtype)
+                latents_mean = latents_mean.to(device=device, dtype=dtype)
+                latents_std = latents_std.to(device=device, dtype=dtype)
                 init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
             else:
                 init_latents = self.vae.config.scaling_factor * init_latents
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index e9a9b93863..1283bbccf3 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -723,8 +723,8 @@ class StableDiffusionXLImg2ImgPipeline(
 
             init_latents = init_latents.to(dtype)
             if latents_mean is not None and latents_std is not None:
-                latents_mean = latents_mean.to(device=self.device, dtype=dtype)
-                latents_std = latents_std.to(device=self.device, dtype=dtype)
+                latents_mean = latents_mean.to(device=device, dtype=dtype)
+                latents_std = latents_std.to(device=device, dtype=dtype)
                 init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
             else:
                 init_latents = self.vae.config.scaling_factor * init_latents
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index cb338e5e56..5b86dd0896 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import random
 import unittest
 
@@ -31,6 +32,7 @@ from transformers import (
 from diffusers import (
     AutoencoderKL,
     AutoencoderTiny,
+    EDMDPMSolverMultistepScheduler,
     EulerDiscreteScheduler,
     LCMScheduler,
     StableDiffusionXLImg2ImgPipeline,
@@ -39,7 +41,9 @@ from diffusers import (
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    load_image,
     require_torch_gpu,
+    slow,
     torch_device,
 )
 
@@ -776,3 +780,54 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
 
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
+
+
+@slow
+class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_stable_diffusion_xl_img2img_playground(self):
+        torch.manual_seed(0)
+        model_path = "playgroundai/playground-v2.5-1024px-aesthetic"
+
+        sd_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+            model_path, torch_dtype=torch.float16, variant="fp16", add_watermarker=False
+        )
+
+        sd_pipe.enable_model_cpu_offload()
+        sd_pipe.scheduler = EDMDPMSolverMultistepScheduler.from_config(
+            sd_pipe.scheduler.config, use_karras_sigmas=True
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "a photo of an astronaut riding a horse on mars"
+
+        url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+
+        init_image = load_image(url).convert("RGB")
+
+        image = sd_pipe(
+            prompt,
+            num_inference_steps=30,
+            guidance_scale=8.0,
+            image=init_image,
+            height=1024,
+            width=1024,
+            output_type="np",
+        ).images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 1024, 1024, 3)
+
+        expected_slice = np.array([0.3519, 0.3149, 0.3364, 0.3505, 0.3402, 0.3371, 0.3554, 0.3495, 0.3333])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
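
Note (not part of the patch): the core change above replaces `self.device` with the `device` argument passed into `prepare_latents`. A minimal sketch of why this matters, assuming a checkpoint such as Playground v2.5 whose VAE config provides `latents_mean`/`latents_std` and a pipeline running with `enable_model_cpu_offload()`: under offload the pipeline modules stay on the CPU until they are used, so `self.device` can report "cpu" while the latents being normalized already live on the accelerator, which leads to a device-mismatch error. The variable names and scaling factor below are illustrative stand-ins, not the pipeline's actual internals.

    import torch

    # Execution device that prepare_latents receives (kept generic so the sketch runs anywhere).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Stand-ins for the pipeline state under model CPU offload:
    # the encoded image latents already live on the execution device ...
    init_latents = torch.randn(1, 4, 128, 128, device=device)
    # ... while the VAE normalization statistics start out on the CPU.
    latents_mean = torch.zeros(1, 4, 1, 1)
    latents_std = torch.ones(1, 4, 1, 1)

    # Buggy pattern: under offload, self.device reports "cpu", so the statistics
    # would stay on the CPU and the arithmetic below would fail on a GPU run.
    # latents_mean = latents_mean.to(device="cpu", dtype=init_latents.dtype)

    # Fixed pattern (what the patch does): move the statistics to the execution device.
    latents_mean = latents_mean.to(device=device, dtype=init_latents.dtype)
    latents_std = latents_std.to(device=device, dtype=init_latents.dtype)

    scaling_factor = 0.5  # illustrative value, not the real vae.config.scaling_factor
    init_latents = (init_latents - latents_mean) * scaling_factor / latents_std
    print(init_latents.device)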