1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix/update the LDM pipeline and tests (#1743)

* Fix/update LDM tests

* batched generators
This commit is contained in:
Anton Lozhkov
2022-12-18 11:49:53 +01:00
committed by GitHub
parent 08cc36ddff
commit c2a38ef9df
2 changed files with 135 additions and 97 deletions

View File

@@ -128,29 +128,42 @@ class LDMTextToImagePipeline(DiffusionPipeline):
# get unconditional embeddings for classifier free guidance
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
)
uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
# get prompt text embeddings
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
# get the initial random noise unless the user supplied it
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
if self.device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu").to(self.device)
rand_device = "cpu" if self.device.type == "mps" else self.device
if isinstance(generator, list):
latents_shape = (1,) + latents_shape[1:]
latents = [
torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0)
else:
latents = torch.randn(
latents_shape,
generator=generator,
device=self.device,
latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
)
latents = latents.to(self.device)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
latents = latents.to(self.device)
self.scheduler.set_timesteps(num_inference_steps)

View File

@@ -13,24 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
from diffusers.utils.testing_utils import require_torch, slow, torch_device
from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class LDMTextToImagePipelineFastTests(unittest.TestCase):
@property
def dummy_cond_unet(self):
class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LDMTextToImagePipeline
test_cpu_offload = False
def get_dummy_components(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
@@ -40,25 +45,24 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model
@property
def dummy_vae(self):
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
model = AutoencoderKL(
block_out_channels=[32, 64],
vae = AutoencoderKL(
block_out_channels=(32, 64),
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
latent_channels=4,
)
return model
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
@@ -69,96 +73,117 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config)
def test_inference_text2img(self):
if torch_device != "cpu":
return
unet = self.dummy_cond_unet
scheduler = DDIMScheduler()
vae = self.dummy_vae
bert = self.dummy_text_encoder
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
components = {
"unet": unet,
"scheduler": scheduler,
"vqvae": vae,
"bert": text_encoder,
"tokenizer": tokenizer,
}
return components
prompt = "A painting of a squirrel eating a burger"
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
}
return inputs
# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
).images
def test_inference_text2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image = ldm(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
).images
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = ldm(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="numpy",
return_dict=False,
)[0]
components = self.get_dummy_components()
pipe = LDMTextToImagePipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 16, 16, 3)
expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow
@require_torch
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
def test_inference_text2img(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
@require_torch_gpu
class LDMTextToImagePipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
prompt = "A painting of a squirrel eating a burger"
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"latents": latents,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 6.0,
"output_type": "numpy",
}
return inputs
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
def test_ldm_default_ddim(self):
pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
pipe.set_progress_bar_config(disable=None)
image = ldm(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
).images
image_slice = image[0, -3:, -3:, -1]
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
max_diff = np.abs(expected_slice - image_slice).max()
assert max_diff < 1e-3
def test_inference_text2img_fast(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
@nightly
@require_torch_gpu
class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"latents": latents,
"generator": generator,
"num_inference_steps": 50,
"guidance_scale": 6.0,
"output_type": "numpy",
}
return inputs
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
def test_ldm_default_ddim(self):
pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
pipe.set_progress_bar_config(disable=None)
image_slice = image[0, -3:, -3:, -1]
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3