Fix/update the LDM pipeline and tests (#1743)
* Fix/update LDM tests
* batched generators
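The main user-facing change is that the pipeline now accepts a list of torch.Generator objects, one per prompt, so every image in a batch can be reproduced on its own. A minimal usage sketch (not taken from the PR itself; the checkpoint name and call arguments mirror the tests below, while the seeds and device are illustrative):

```python
import torch
from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
pipe.to("cuda")

prompts = ["A painting of a squirrel eating a burger"] * 2
# one generator per prompt; the list length must match the batch size,
# otherwise the pipeline now raises a ValueError
generators = [torch.Generator(device="cuda").manual_seed(seed) for seed in (0, 1)]

images = pipe(
    prompts,
    generator=generators,
    num_inference_steps=50,
    guidance_scale=6.0,
    output_type="numpy",
).images  # array of shape (2, 256, 256, 3)
```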
@@ -128,29 +128,42 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]

         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
         text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]

         # get the initial random noise unless the user supplied it
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         if latents is None:
-            if self.device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu").to(self.device)
+            rand_device = "cpu" if self.device.type == "mps" else self.device
+
+            if isinstance(generator, list):
+                latents_shape = (1,) + latents_shape[1:]
+                latents = [
+                    torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
+                    for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0)
             else:
                 latents = torch.randn(
-                    latents_shape,
-                    generator=generator,
-                    device=self.device,
+                    latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
                 )
+            latents = latents.to(self.device)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+            latents = latents.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)
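The heart of the batched-generator support above is drawing one latent per generator on a (1, ...) shape and concatenating along the batch dimension, so each sample depends only on its own seed. A standalone sketch of that pattern (shape and seeds are illustrative):

```python
import torch

batch_size = 2
latents_shape = (batch_size, 4, 32, 32)
generators = [torch.Generator().manual_seed(seed) for seed in (0, 1)]

# sample each latent separately with its own generator, then concatenate
per_sample_shape = (1,) + latents_shape[1:]
latents = torch.cat(
    [torch.randn(per_sample_shape, generator=generators[i]) for i in range(batch_size)],
    dim=0,
)
assert latents.shape == latents_shape
```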
@@ -13,24 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import unittest

 import numpy as np
 import torch

 from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

+from ...test_pipelines_common import PipelineTesterMixin
+

 torch.backends.cuda.matmul.allow_tf32 = False


-class LDMTextToImagePipelineFastTests(unittest.TestCase):
-    @property
-    def dummy_cond_unet(self):
+class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = LDMTextToImagePipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
         torch.manual_seed(0)
-        model = UNet2DConditionModel(
+        unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
             sample_size=32,
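With the change above, the fast tests plug into the shared PipelineTesterMixin: the test class only declares pipeline_class and implements get_dummy_components / get_dummy_inputs, and the mixin's common tests build the pipeline from those hooks. Roughly like the following (a hypothetical sketch of the pattern, not the mixin's actual implementation):

```python
# Hypothetical illustration of how a PipelineTesterMixin-style check consumes
# the hooks declared above; the real mixin in diffusers may differ in detail.
def run_fast_inference_check(test_case, device="cpu"):
    pipe = test_case.pipeline_class(**test_case.get_dummy_components())
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    inputs = test_case.get_dummy_inputs(device)
    return pipe(**inputs).images
```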
@@ -40,25 +45,24 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
         )
-        return model
-
-    @property
-    def dummy_vae(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
         torch.manual_seed(0)
-        model = AutoencoderKL(
-            block_out_channels=[32, 64],
+        vae = AutoencoderKL(
+            block_out_channels=(32, 64),
             in_channels=3,
             out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
             latent_channels=4,
         )
-        return model
-
-    @property
-    def dummy_text_encoder(self):
         torch.manual_seed(0)
-        config = CLIPTextConfig(
+        text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
             hidden_size=32,
@@ -69,96 +73,117 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             pad_token_id=1,
             vocab_size=1000,
         )
-        return CLIPTextModel(config)
-
-    def test_inference_text2img(self):
-        if torch_device != "cpu":
-            return
-
-        unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler()
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

-        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vqvae": vae,
+            "bert": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components

-        prompt = "A painting of a squirrel eating a burger"
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(
-                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
-            ).images
+    def test_inference_text2img(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
-        ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image_from_tuple = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-            return_dict=False,
-        )[0]
+        components = self.get_dummy_components()
+        pipe = LDMTextToImagePipeline(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)

+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

         assert image.shape == (1, 16, 16, 3)
-        expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3


 @slow
-@require_torch
-class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
-    def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
+@require_torch_gpu
+class LDMTextToImagePipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()

-        prompt = "A painting of a squirrel eating a burger"
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs

-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)

-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
-        ).images
-
-        image_slice = image[0, -3:, -3:, -1]
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()

         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
+        max_diff = np.abs(expected_slice - image_slice).max()
+        assert max_diff < 1e-3

-    def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)

-        prompt = "A painting of a squirrel eating a burger"
+@nightly
+@require_torch_gpu
+class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()

-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 50,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs

-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)

-        image_slice = image[0, -3:, -3:, -1]
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images[0]

-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 1e-3
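The new slow and nightly tests get their determinism from fixed NumPy-seeded latents plus a seeded generator rather than from device-dependent torch.randn calls. A minimal sketch of that recipe outside the test classes (checkpoint, shapes, and call arguments are taken from the tests above; the device is illustrative):

```python
import numpy as np
import torch
from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to("cuda")

seed = 0
# np.random.RandomState gives the same starting noise on every machine
latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
latents = torch.from_numpy(latents).to(device="cuda", dtype=torch.float32)
generator = torch.Generator(device="cuda").manual_seed(seed)

image = pipe(
    "A painting of a squirrel eating a burger",
    latents=latents,
    generator=generator,
    num_inference_steps=3,
    guidance_scale=6.0,
    output_type="numpy",
).images
assert image.shape == (1, 256, 256, 3)
```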