diff --git a/MANIFEST.in b/MANIFEST.in index b2a3d6cff1..b22fe1a28a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ +include LICENSE include src/diffusers/utils/model_card_template.md diff --git a/README.md b/README.md index f2abe1978a..734cba6efd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@


- +

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 3ccdff9ba2..7a5c877c62 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -50,6 +50,7 @@ if is_transformers_available(): INDEX_FILE = "diffusion_pytorch_model.bin" CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" +DUMMY_MODULES_FOLDER = "diffusers.utils" logger = logging.get_logger(__name__) @@ -473,9 +474,20 @@ class DiffusionPipeline(ConfigMixin): if issubclass(class_obj, class_candidate): load_method_name = importable_classes[class_name][1] - load_method = getattr(class_obj, load_method_name) + if load_method_name is None: + none_module = class_obj.__module__ + if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module: + # call class_obj for nice error message of missing requirements + class_obj() + raise ValueError( + f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" + f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." + ) + + load_method = getattr(class_obj, load_method_name) loading_kwargs = {} + if issubclass(class_obj, torch.nn.Module): loading_kwargs["torch_dtype"] = torch_dtype if issubclass(class_obj, diffusers.OnnxRuntimeModel): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 72e15f4f90..3e5ac4b335 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -195,6 +195,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): """ if isinstance(prompt, str): batch_size = 1 + prompt = [prompt] elif isinstance(prompt, list): batch_size = len(prompt) else: @@ -284,8 +285,23 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): init_latents = init_latent_dist.sample(generator=generator) init_latents = 0.18215 * init_latents - # expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many init images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = len(prompt) // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py index 51ee3b1848..14830bca28 100644 --- a/src/diffusers/utils/dummy_flax_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -9,3 +9,11 @@ class FlaxStableDiffusionPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 4ab14f752c..708022d85b 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -10,6 +10,14 @@ class FlaxModelMixin(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxUNet2DConditionModel(metaclass=DummyObject): _backends = ["flax"] @@ -17,6 +25,14 @@ class FlaxUNet2DConditionModel(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxAutoencoderKL(metaclass=DummyObject): _backends = ["flax"] @@ -24,6 +40,14 @@ class FlaxAutoencoderKL(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxDiffusionPipeline(metaclass=DummyObject): _backends = ["flax"] @@ -31,6 +55,14 @@ class FlaxDiffusionPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxDDIMScheduler(metaclass=DummyObject): _backends = ["flax"] @@ -38,6 +70,14 @@ class FlaxDDIMScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxDDPMScheduler(metaclass=DummyObject): _backends = ["flax"] @@ -45,6 +85,14 @@ class FlaxDDPMScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxKarrasVeScheduler(metaclass=DummyObject): _backends = ["flax"] @@ -52,6 +100,14 @@ class FlaxKarrasVeScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxLMSDiscreteScheduler(metaclass=DummyObject): _backends = ["flax"] @@ -59,6 +115,14 @@ class FlaxLMSDiscreteScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxPNDMScheduler(metaclass=DummyObject): _backends = ["flax"] @@ -66,6 +130,14 @@ class FlaxPNDMScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxSchedulerMixin(metaclass=DummyObject): _backends = ["flax"] @@ -73,9 +145,25 @@ class FlaxSchedulerMixin(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 531c0b7766..ee748f5b1d 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -10,6 +10,14 @@ class ModelMixin(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class AutoencoderKL(metaclass=DummyObject): _backends = ["torch"] @@ -17,6 +25,14 @@ class AutoencoderKL(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] @@ -24,6 +40,14 @@ class UNet2DConditionModel(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class UNet2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -31,6 +55,14 @@ class UNet2DModel(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class VQModel(metaclass=DummyObject): _backends = ["torch"] @@ -38,6 +70,14 @@ class VQModel(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) @@ -73,6 +113,14 @@ class DiffusionPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class DDIMPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -80,6 +128,14 @@ class DDIMPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class DDPMPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -87,6 +143,14 @@ class DDPMPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class KarrasVePipeline(metaclass=DummyObject): _backends = ["torch"] @@ -94,6 +158,14 @@ class KarrasVePipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class LDMPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -101,6 +173,14 @@ class LDMPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class PNDMPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -108,6 +188,14 @@ class PNDMPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class ScoreSdeVePipeline(metaclass=DummyObject): _backends = ["torch"] @@ -115,6 +203,14 @@ class ScoreSdeVePipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -122,6 +218,14 @@ class DDIMScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class DDPMScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -129,6 +233,14 @@ class DDPMScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class KarrasVeScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -136,6 +248,14 @@ class KarrasVeScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class PNDMScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -143,6 +263,14 @@ class PNDMScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class SchedulerMixin(metaclass=DummyObject): _backends = ["torch"] @@ -150,6 +278,14 @@ class SchedulerMixin(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class ScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -157,9 +293,25 @@ class ScoreSdeVeScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class EMAModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) diff --git a/src/diffusers/utils/dummy_torch_and_scipy_objects.py b/src/diffusers/utils/dummy_torch_and_scipy_objects.py index 49c8956483..13f17349bb 100644 --- a/src/diffusers/utils/dummy_torch_and_scipy_objects.py +++ b/src/diffusers/utils/dummy_torch_and_scipy_objects.py @@ -9,3 +9,11 @@ class LMSDiscreteScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "scipy"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "scipy"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "scipy"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py index 967e231d87..d099b83729 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -9,3 +9,11 @@ class StableDiffusionOnnxPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6e4ab48c33..ce9a02bca1 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -10,6 +10,14 @@ class LDMTextToImagePipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -17,6 +25,14 @@ class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + class StableDiffusionInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -24,9 +40,25 @@ class StableDiffusionInpaintPipeline(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + class StableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 30beb033fc..8004241ac1 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -492,6 +492,12 @@ class PipelineFastTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_from_pretrained_error_message_uninstalled_packages(self): + # TODO(Patrick, Pedro) - need better test here for the future + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-lms-pipe") + assert isinstance(pipe, StableDiffusionPipeline) + assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + def test_stable_diffusion_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet @@ -698,6 +704,48 @@ class PipelineFastTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_img2img_multiple_init_images(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=self.dummy_safety_checker, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = 2 * ["A painting of a squirrel eating a burger"] + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + prompt, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + + image = output.images + + image_slice = image[-1, -3:, -3:, -1] + + assert image.shape == (2, 32, 32, 3) + expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_img2img_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet diff --git a/utils/check_dummies.py b/utils/check_dummies.py index 60c954ad2f..82d54b2a29 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -38,6 +38,14 @@ class {0}(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, {1}) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, {1}) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, {1}) """