[SD Img2Img] resize source images to multiple of 8 instead of 32 (#1571)
* [Stable Diffusion Img2Img] resize source images to integer multiple of 8 instead of 32
* [Alt Diffusion Img2Img] resize source images to multiple of 8 instead of 32
* [Img2Img] fix AltDiffusion Img2Img resolution test
* [Img2Img] add Stable Diffusion Img2Img resolution test
* [Cycle Diffusion] round resolution to multiples of 8 instead of 32
* [ONNX SD Img2Img] round resolution to multiples of 64 instead of 32
* [SD Depth2Img] round resolution to multiples of 8 instead of 32
* [Repaint] round resolution to multiples of 8 instead of 32
* fix make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
commit 9b37ed33b5 (parent 135567f18e), committed via GitHub
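Why multiples of 8: the Stable Diffusion VAE encoder downsamples each spatial dimension by a factor of 8, so any width/height that is an integer multiple of 8 already maps to a valid latent shape. Rounding down to multiples of 32, as the old preprocessing did, cropped more of the source image than necessary. A minimal sketch of the difference, using the 760x504 resolution from the new tests (illustrative only, not part of the diff):

# Sketch: rounding 760x504 down to a multiple of 32 vs. 8.
w, h = 760, 504
old_w, old_h = map(lambda x: x - x % 32, (w, h))  # -> (736, 480): 24 px lost per dimension
new_w, new_h = map(lambda x: x - x % 8, (w, h))   # -> (760, 504): image kept intact
latent_w, latent_h = new_w // 8, new_h // 8       # -> (95, 63): valid latent resolution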
@@ -80,7 +80,7 @@ def preprocess(image):
     if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
-        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
 
         image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
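For readers who want to run the changed code path in isolation, here is a self-contained sketch of the updated preprocess logic. It substitutes PIL.Image.LANCZOS for diffusers' PIL_INTERPOLATION["lanczos"] wrapper and uses a hypothetical name, preprocess_sketch, so as not to imply this is the library function verbatim:

import numpy as np
import PIL.Image

def preprocess_sketch(images):
    # images: a list of PIL images, as in the branch guarded by the
    # isinstance(image[0], PIL.Image.Image) check above
    w, h = images[0].size
    w, h = map(lambda x: x - x % 8, (w, h))  # round each side down to a multiple of 8
    arrays = [np.array(i.resize((w, h), resample=PIL.Image.LANCZOS))[None, :] for i in images]
    return np.concatenate(arrays, axis=0)  # (batch, height, width, channels) uint8 array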
@@ -38,7 +38,7 @@ def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
     if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
-        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
 
         image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
@@ -44,7 +44,7 @@ def preprocess(image):
     if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
-        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
 
         image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
@@ -32,7 +32,7 @@ from . import StableDiffusionPipelineOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64
 def preprocess(image):
     if isinstance(image, torch.Tensor):
         return image
@@ -41,7 +41,7 @@ def preprocess(image):
     if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
-        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
 
         image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
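The ONNX pipeline is the outlier: it rounds to multiples of 64 rather than 8, presumably because the exported ONNX graph is less tolerant of arbitrary spatial shapes than the PyTorch UNet. The updated "# Copied from ... with 8->64" marker tells diffusers' copy-consistency tooling (make fix-copies) to compare this function against the canonical one with 8 substituted by 64. The coarser rounding on the same 760x504 example (illustrative only):

# Sketch: the same 760x504 input under the ONNX pipeline's % 64 rounding.
w, h = 760, 504
w, h = map(lambda x: x - x % 64, (w, h))  # -> (704, 448): 56 px lost per dimension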
@@ -49,7 +49,7 @@ def preprocess(image):
     if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
-        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
 
         image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
@@ -84,7 +84,7 @@ def preprocess(image):
     if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
-        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
 
         image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
@@ -207,6 +207,43 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
 
         assert image.shape == (1, 32, 32, 3)
 
+    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+    def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        # resize to resolution that is divisible by 8 but not 16 or 32
+        init_image = init_image.resize((760, 504))
+
+        model_id = "BAAI/AltDiffusion"
+        pipe = AltDiffusionImg2ImgPipeline.from_pretrained(
+            model_id,
+            safety_checker=None,
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output = pipe(
+            prompt=prompt,
+            image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+        )
+        image = output.images[0]
+
+        image_slice = image[255:258, 383:386, -1]
+
+        assert image.shape == (504, 760, 3)
+        expected_slice = np.array([0.3252, 0.3340, 0.3418, 0.3263, 0.3346, 0.3300, 0.3163, 0.3470, 0.3427])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
 
 @slow
 @require_torch_gpu
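Both new tests resize the source to 760x504 deliberately: each side is divisible by 8 but not by 16 (and hence not by 32), so the old rounding would have cropped the image while the fixed pipelines must pass it through untouched. Note also that the shape assertion is (504, 760, 3) because PIL reports size as (width, height) while the numpy output is (height, width, channels). A quick check of the chosen resolution (illustrative only):

# Sketch: why (760, 504) exercises the new rounding.
for side in (760, 504):
    assert side % 8 == 0      # accepted by the new rounding without cropping
    assert side % 16 != 0     # would have been cropped by the old % 32 rounding
print(760 // 8, 504 // 8)     # odd factors (95, 63) confirm non-divisibility by 16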
@@ -333,6 +333,42 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
         # make sure that less than 2.2 GB is allocated
         assert mem_bytes < 2.2 * 10**9
 
+    def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        # resize to resolution that is divisible by 8 but not 16 or 32
+        init_image = init_image.resize((760, 504))
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+            model_id,
+            safety_checker=None,
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output = pipe(
+            prompt=prompt,
+            image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+        )
+        image = output.images[0]
+
+        image_slice = image[255:258, 383:386, -1]
+
+        assert image.shape == (504, 760, 3)
+        expected_slice = np.array([0.7124, 0.7105, 0.6993, 0.7140, 0.7106, 0.6945, 0.7198, 0.7172, 0.7031])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
 
 @nightly
 @require_torch_gpu