mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
🎨 fix xl playground device (#8550)
* 🎨 fix xl playground device
* 🎨 run `make fix-copies`
* 🎨 run `make fix-copies`
* edit xl_controlnet_img2img file
* edit playground img2img test slow
* Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -949,8 +949,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
init_latents = init_latents.to(dtype)
|
||||
if latents_mean is not None and latents_std is not None:
|
||||
latents_mean = latents_mean.to(device=self.device, dtype=dtype)
|
||||
latents_std = latents_std.to(device=self.device, dtype=dtype)
|
||||
latents_mean = latents_mean.to(device=device, dtype=dtype)
|
||||
latents_std = latents_std.to(device=device, dtype=dtype)
|
||||
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
|
||||
else:
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
@@ -723,8 +723,8 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
|
||||
init_latents = init_latents.to(dtype)
|
||||
if latents_mean is not None and latents_std is not None:
|
||||
latents_mean = latents_mean.to(device=self.device, dtype=dtype)
|
||||
latents_std = latents_std.to(device=self.device, dtype=dtype)
|
||||
latents_mean = latents_mean.to(device=device, dtype=dtype)
|
||||
latents_std = latents_std.to(device=device, dtype=dtype)
|
||||
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
|
||||
else:
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
@@ -31,6 +32,7 @@ from transformers import (
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
EDMDPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LCMScheduler,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
@@ -39,7 +41,9 @@ from diffusers import (
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -776,3 +780,54 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
|
||||
|
||||
def test_save_load_optional_components(self):
    # Thin wrapper: forwards to the `_test_save_load_optional_components`
    # helper on this test case, which is defined outside this excerpt.
    self._test_save_load_optional_components()
|
||||
|
||||
|
||||
@slow
class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase):
    """Slow integration test for the SDXL img2img pipeline on the Playground v2.5 checkpoint."""

    def setUp(self):
        # Run garbage collection and empty the CUDA cache before each test.
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        # Same cleanup after each test.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_xl_img2img_playground(self):
        """Run 30 img2img steps with model CPU offload and compare a 3x3 output slice to reference values."""
        # Seed the global RNG first; no explicit generator is passed to the pipeline call below.
        torch.manual_seed(0)

        pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "playgroundai/playground-v2.5-1024px-aesthetic",
            torch_dtype=torch.float16,
            variant="fp16",
            add_watermarker=False,
        )
        pipeline.enable_model_cpu_offload()

        # Replace the checkpoint's scheduler with an EDM DPM-Solver built from the same config.
        scheduler_config = pipeline.scheduler.config
        pipeline.scheduler = EDMDPMSolverMultistepScheduler.from_config(scheduler_config, use_karras_sigmas=True)
        pipeline.set_progress_bar_config(disable=None)

        text_prompt = "a photo of an astronaut riding a horse on mars"
        image_url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
        source_image = load_image(image_url).convert("RGB")

        call_kwargs = {
            "num_inference_steps": 30,
            "guidance_scale": 8.0,
            "image": source_image,
            "height": 1024,
            "width": 1024,
            "output_type": "np",
        }
        generated = pipeline(text_prompt, **call_kwargs).images

        # Bottom-right 3x3 patch of the last channel of the first (only) image.
        corner_patch = generated[0, -3:, -3:, -1]

        assert generated.shape == (1, 1024, 1024, 3)

        reference_patch = np.array([0.3519, 0.3149, 0.3364, 0.3505, 0.3402, 0.3371, 0.3554, 0.3495, 0.3333])
        max_abs_diff = np.abs(corner_patch.flatten() - reference_patch).max()

        assert max_abs_diff < 1e-2
|
||||
|
||||
Reference in New Issue
Block a user