diff --git a/.github/workflows/pr_test_peft_backend.yml b/.github/workflows/pr_test_peft_backend.yml index 97aea28bdb..8cc4eb6e59 100644 --- a/.github/workflows/pr_test_peft_backend.yml +++ b/.github/workflows/pr_test_peft_backend.yml @@ -59,7 +59,7 @@ jobs: - name: Run fast PyTorch LoRA CPU tests with PEFT backend run: | - python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v \ --make-reports=tests_${{ matrix.config.report }} \ tests/lora/test_lora_layers_peft.py diff --git a/docker/diffusers-pytorch-compile-cuda/Dockerfile b/docker/diffusers-pytorch-compile-cuda/Dockerfile index da9f372bd6..e1e63758a4 100644 --- a/docker/diffusers-pytorch-compile-cuda/Dockerfile +++ b/docker/diffusers-pytorch-compile-cuda/Dockerfile @@ -40,7 +40,6 @@ RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \ numpy \ scipy \ tensorboard \ - transformers \ - omegaconf - + transformers + CMD ["/bin/bash"] diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile index 877bc6840e..7fc3d8ced9 100644 --- a/docker/diffusers-pytorch-cuda/Dockerfile +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ scipy \ tensorboard \ transformers \ - omegaconf \ pytorch-lightning CMD ["/bin/bash"] diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile index 003f8e1165..8f2619c623 100644 --- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile +++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile @@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ scipy \ tensorboard \ transformers \ - omegaconf \ xformers CMD ["/bin/bash"] diff --git a/docs/source/en/api/models/autoencoderkl.md b/docs/source/en/api/models/autoencoderkl.md index 72427ab30e..3534c8250d 100644 --- a/docs/source/en/api/models/autoencoderkl.md +++ b/docs/source/en/api/models/autoencoderkl.md @@ -33,6 +33,9 @@ model = AutoencoderKL.from_single_file(url) ## AutoencoderKL [[autodoc]] AutoencoderKL + - decode + - encode + - all ## AutoencoderKLOutput diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index fb38687e88..4e1670df77 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -235,6 +235,62 @@ export_to_gif(frames, "animation.gif") +## Using FreeInit + +[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu. + +FreeInit is an effective method that improves temporal consistency and overall quality of videos generated using video-diffusion-models without any addition training. It can be applied to AnimateDiff, ModelScope, VideoCrafter and various other video generation models seamlessly at inference time, and works by iteratively refining the latent-initialization noise. More details can be found it the paper. + +The following example demonstrates the usage of FreeInit. 
+ +```python +import torch +from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler +from diffusers.utils import export_to_gif + +adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") +model_id = "SG161222/Realistic_Vision_V5.1_noVAE" +pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda") +pipe.scheduler = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + beta_schedule="linear", + clip_sample=False, + timestep_spacing="linspace", + steps_offset=1 +) + +# enable memory savings +pipe.enable_vae_slicing() +pipe.enable_vae_tiling() + +# enable FreeInit +# Refer to the enable_free_init documentation for a full list of configurable parameters +pipe.enable_free_init(method="butterworth", use_fast_sampling=True) + +# run inference +output = pipe( + prompt="a panda playing a guitar, on a boat, in the ocean, high quality", + negative_prompt="bad quality, worse quality", + num_frames=16, + guidance_scale=7.5, + num_inference_steps=20, + generator=torch.Generator("cpu").manual_seed(666), +) + +# disable FreeInit +pipe.disable_free_init() + +frames = output.frames[0] +export_to_gif(frames, "animation.gif") +``` + + + +FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models). + + + Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. @@ -248,6 +304,8 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) - __call__ - enable_freeu - disable_freeu + - enable_free_init + - disable_free_init - enable_vae_slicing - disable_vae_slicing - enable_vae_tiling diff --git a/docs/source/en/installation.md b/docs/source/en/installation.md index 3bf1d46fd0..8e0eddd175 100644 --- a/docs/source/en/installation.md +++ b/docs/source/en/installation.md @@ -37,8 +37,10 @@ source .env/bin/activate You should also install 🤗 Transformers because 🤗 Diffusers relies on its models: + +Note - PyTorch only supports Python 3.8 - 3.11 on Windows. 
```bash pip install diffusers["torch"] transformers ``` diff --git a/docs/source/en/using-diffusers/controlnet.md b/docs/source/en/using-diffusers/controlnet.md index e7f6eb2756..ac4bfa4472 100644 --- a/docs/source/en/using-diffusers/controlnet.md +++ b/docs/source/en/using-diffusers/controlnet.md @@ -429,7 +429,7 @@ image = pipe( make_image_grid([original_image, canny_image, image], rows=1, cols=3) ``` -### MultiControlNet +## MultiControlNet diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index d9d4a675dd..0ef90c6dd9 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -344,7 +344,8 @@ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-a IP-Adapter relies on an image encoder to generate the image features, if your IP-Adapter weights folder contains a "image_encoder" subfolder, the image encoder will be automatically loaded and registered to the pipeline. Otherwise you can so load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to a Stable Diffusion pipeline when you create it. ```py -from diffusers import AutoPipelineForText2Image, CLIPVisionModelWithProjection +from diffusers import AutoPipelineForText2Image +from transformers import CLIPVisionModelWithProjection import torch image_encoder = CLIPVisionModelWithProjection.from_pretrained( diff --git a/docs/source/en/using-diffusers/sdxl.md b/docs/source/en/using-diffusers/sdxl.md index 25b581fc6f..906eb0dbbe 100644 --- a/docs/source/en/using-diffusers/sdxl.md +++ b/docs/source/en/using-diffusers/sdxl.md @@ -26,7 +26,7 @@ Before you begin, make sure you have the following libraries installed: ```py # uncomment to install the necessary libraries in Colab -#!pip install -q diffusers transformers accelerate omegaconf invisible-watermark>=0.2.0 +#!pip install -q diffusers transformers accelerate invisible-watermark>=0.2.0 ``` diff --git a/docs/source/en/using-diffusers/sdxl_turbo.md b/docs/source/en/using-diffusers/sdxl_turbo.md index 99e1c7000e..ceca5729b1 100644 --- a/docs/source/en/using-diffusers/sdxl_turbo.md +++ b/docs/source/en/using-diffusers/sdxl_turbo.md @@ -23,7 +23,7 @@ Before you begin, make sure you have the following libraries installed: ```py # uncomment to install the necessary libraries in Colab -#!pip install -q diffusers transformers accelerate omegaconf +#!pip install -q diffusers transformers accelerate ``` ## Load model checkpoints diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index ddd8114ae4..e35630e3e8 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -38,7 +38,7 @@ from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version -from peft import LoraConfig +from peft import LoraConfig, set_peft_model_state_dict from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose @@ -58,15 +58,17 @@ from diffusers import ( ) from diffusers.loaders import LoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr +from diffusers.training_utils 
import _set_state_dict_into_text_encoder, cast_training_params, compute_snr from diffusers.utils import ( check_min_version, convert_all_state_dict_to_peft, convert_state_dict_to_diffusers, convert_state_dict_to_kohya, + convert_unet_state_dict_to_peft, is_wandb_available, ) from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -1277,7 +1279,7 @@ def main(args): for name, param in text_encoder_one.named_parameters(): if "token_embedding" in name: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param = param.to(dtype=torch.float32) + param.data = param.to(dtype=torch.float32) param.requires_grad = True text_lora_parameters_one.append(param) else: @@ -1286,22 +1288,16 @@ def main(args): for name, param in text_encoder_two.named_parameters(): if "token_embedding" in name: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param = param.to(dtype=torch.float32) + param.data = param.to(dtype=torch.float32) param.requires_grad = True text_lora_parameters_two.append(param) else: param.requires_grad = False - # Make sure the trainable params are in float32. - if args.mixed_precision == "fp16": - models = [unet] - if args.train_text_encoder: - models.extend([text_encoder_one, text_encoder_two]) - for model in models: - for param in model.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -1313,14 +1309,14 @@ def main(args): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): if args.train_text_encoder: text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): if args.train_text_encoder: text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) @@ -1348,27 +1344,44 @@ def main(args): while len(models) > 0: model = models.pop() - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, 
network_alphas=network_alphas, unet=unet_) - text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ - ) + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) - text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ - ) + if args.train_text_encoder: + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [unet_] + if args.train_text_encoder: + models.extend([text_encoder_one_, text_encoder_two_]) + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -1383,6 +1396,13 @@ def main(args): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + # Make sure the trainable params are in float32. 
+ if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + cast_training_params(models, dtype=torch.float32) + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: @@ -1705,19 +1725,19 @@ def main(args): num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) elif args.train_text_encoder_ti: # args.train_text_encoder_ti num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) - + # flag used for textual inversion + pivoted = False for epoch in range(first_epoch, args.num_train_epochs): # if performing any kind of optimization of text_encoder params if args.train_text_encoder or args.train_text_encoder_ti: if epoch == num_train_epochs_text_encoder: print("PIVOT HALFWAY", epoch) # stopping optimization of text_encoder params - # re setting the optimizer to optimize only on unet params - optimizer.param_groups[1]["lr"] = 0.0 - optimizer.param_groups[2]["lr"] = 0.0 + # this flag is used to reset the optimizer to optimize only on unet params + pivoted = True else: - # still optimizng the text encoder + # still optimizing the text encoder text_encoder_one.train() text_encoder_two.train() # set top parameter requires_grad = True for gradient checkpointing works @@ -1727,6 +1747,12 @@ def main(args): unet.train() for step, batch in enumerate(train_dataloader): + if pivoted: + # stopping optimization of text_encoder params + # re setting the optimizer to optimize only on unet params + optimizer.param_groups[1]["lr"] = 0.0 + optimizer.param_groups[2]["lr"] = 0.0 + with accelerator.accumulate(unet): prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -1865,8 +1891,7 @@ def main(args): # every step, we reset the embeddings to the original embeddings. if args.train_text_encoder_ti: - for idx, text_encoder in enumerate(text_encoders): - embedding_handler.retract_embeddings() + embedding_handler.retract_embeddings() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/examples/community/README.md b/examples/community/README.md index 7d8d190f03..45d393f84f 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -58,6 +58,7 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap | Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. 
| [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) | | Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) | | StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) | +| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. ```py @@ -2989,7 +2990,7 @@ pipe = DiffusionPipeline.from_pretrained( custom_pipeline="pipeline_animatediff_controlnet", ).to(device="cuda", dtype=torch.float16) pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained( - model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1 + model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear", ) pipe.enable_vae_slicing() @@ -3005,7 +3006,7 @@ result = pipe( width=512, height=768, conditioning_frames=conditioning_frames, - num_inference_steps=12, + num_inference_steps=20, ).frames[0] from diffusers.utils import export_to_gif @@ -3029,6 +3030,79 @@ export_to_gif(result.frames[0], "result.gif") +You can also use multiple controlnets at once! 
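+Note that the `load_video` helper in the example below also uses `requests`, `io.BytesIO`, and `imageio`, which are not included in the imports shown; add `import requests`, `from io import BytesIO`, and `import imageio` (and `pip install requests imageio`) before running it, alongside `controlnet_aux` for the frame processors.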
+ +```python +import torch +from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter +from diffusers.pipelines import DiffusionPipeline +from diffusers.schedulers import DPMSolverMultistepScheduler +from PIL import Image + +motion_id = "guoyww/animatediff-motion-adapter-v1-5-2" +adapter = MotionAdapter.from_pretrained(motion_id) +controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16) +controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) +vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + +model_id = "SG161222/Realistic_Vision_V5.1_noVAE" +pipe = DiffusionPipeline.from_pretrained( + model_id, + motion_adapter=adapter, + controlnet=[controlnet1, controlnet2], + vae=vae, + custom_pipeline="pipeline_animatediff_controlnet", +).to(device="cuda", dtype=torch.float16) +pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained( + model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear", +) +pipe.enable_vae_slicing() + +def load_video(file_path: str): + images = [] + + if file_path.startswith(('http://', 'https://')): + # If the file_path is a URL + response = requests.get(file_path) + response.raise_for_status() + content = BytesIO(response.content) + vid = imageio.get_reader(content) + else: + # Assuming it's a local file path + vid = imageio.get_reader(file_path) + + for frame in vid: + pil_image = Image.fromarray(frame) + images.append(pil_image) + + return images + +video = load_video("dance.gif") + +# You need to install it using `pip install controlnet_aux` +from controlnet_aux.processor import Processor + +p1 = Processor("openpose_full") +cn1 = [p1(frame) for frame in video] + +p2 = Processor("canny") +cn2 = [p2(frame) for frame in video] + +prompt = "astronaut in space, dancing" +negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly" +result = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=768, + conditioning_frames=[cn1, cn2], + num_inference_steps=20, +) + +from diffusers.utils import export_to_gif +export_to_gif(result.frames[0], "result.gif") +``` + ### DemoFusion This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973). @@ -3333,4 +3407,63 @@ images = pipe( # Disable StyleAligned if you do not wish to use it anymore pipe.disable_style_aligned() -``` \ No newline at end of file +``` + +### IP Adapter Face ID +IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded. +You need to install `insightface` and all its requirements to use this model. +You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`. +You have to disable PEFT BACKEND in order to load weights. 
+ +```py +import diffusers +diffusers.utils.USE_PEFT_BACKEND = False +import torch +from diffusers.utils import load_image +import cv2 +import numpy as np +from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler +from insightface.app import FaceAnalysis + + +noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, +) +vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16) +pipeline = DiffusionPipeline.from_pretrained( + "SG161222/Realistic_Vision_V4.0_noVAE", + torch_dtype=torch.float16, + scheduler=noise_scheduler, + vae=vae, + custom_pipeline="ip_adapter_face_id" +) +pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin") +pipeline.to("cuda") + +generator = torch.Generator(device="cpu").manual_seed(42) +num_images=2 + +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png") + +app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, det_size=(640, 640)) +image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB) +faces = app.get(image) +image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) +images = pipeline( + prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower", + image_embeds=image, + negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", + num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704, + generator=generator +).images + +for i in range(num_images): + images[i].save(f"c{i}.png") +``` diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py new file mode 100644 index 0000000000..d9325742cf --- /dev/null +++ b/examples/community/ip_adapter_face_id.py @@ -0,0 +1,1525 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
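+
+# Community "IP Adapter FaceID" pipeline: a Stable Diffusion pipeline conditioned on
+# face embeddings produced by `insightface` and passed in as `image_embeds`, so no
+# CLIP image encoder is needed (see the "IP Adapter Face ID" section of
+# examples/community/README.md).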
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import FusedAttnProcessor2_0 +from diffusers.models.lora import LoRALinearLayer, adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + _get_model_file, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LoRAIPAdapterAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + lora_scale (`float`, defaults to 1.0): + the weight scale of LoRA. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = 
hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAdapterAttnProcessor2_0(nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + lora_scale (`float`, defaults to 1.0): + the weight scale of LoRA. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = 
key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAdapterFullImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from diffusers.models.attention import FeedForward + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.FloatTensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class IPAdapterFaceIDStableDiffusionPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 
+ text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. 
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_name, **kwargs): + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + self._load_ip_adapter_weights(state_dict) + + def convert_ip_adapter_image_proj_to_diffusers(self, state_dict): + updated_state_dict = {} + clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in + norm_layer = "norm.weight" + cross_attention_dim = state_dict[norm_layer].shape[0] + num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim + + image_projection = IPAdapterFullImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim_in, + mult=multiplier, + num_tokens=num_tokens, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + updated_state_dict[diffusers_name] = value + + image_projection.load_state_dict(updated_state_dict) + return image_projection + + def _load_ip_adapter_weights(self, state_dict): + from diffusers.models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + ) + + num_image_text_embeds = 4 + + self.unet.encoder_hid_proj = None + + # set ip-adapter cross-attention processors & load state_dict + attn_procs = {} + key_id = 0 + for name in self.unet.attn_processors.keys(): + 
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.unet.config.block_out_channels[block_id] + if cross_attention_dim is None or "motion_modules" in name: + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() + rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] + attn_module = self.unet + for n in name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + # Set the `lora_layer` attribute of the attention-related matrices. + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, + out_features=attn_module.to_q.out_features, + rank=rank, + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, + out_features=attn_module.to_k.out_features, + rank=rank, + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, + out_features=attn_module.to_v.out_features, + rank=rank, + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=rank, + ) + ) + + value_dict = {} + for k, module in attn_module.named_children(): + index = "." + if not hasattr(module, "set_lora_layer"): + index = ".0." 
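+                        # e.g. `to_out` is an nn.ModuleList (Linear followed by Dropout); the Linear carrying the LoRA layer sits at index 0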
+ module = module[0] + lora_layer = getattr(module, "lora_layer") + for lora_name, w in lora_layer.state_dict().items(): + value_dict.update( + { + f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][ + f"{key_id}.{k}_lora.{lora_name}" + ] + } + ) + + attn_module.load_state_dict(value_dict, strict=False) + attn_module.to(dtype=self.dtype, device=self.device) + key_id += 1 + else: + rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] + attn_processor_class = ( + LoRAIPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else LoRAIPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + rank=rank, + num_tokens=num_image_text_embeds, + ).to(dtype=self.dtype, device=self.device) + + value_dict = {} + for k, w in attn_procs[name].state_dict().items(): + value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) + + attn_procs[name].load_state_dict(value_dict) + key_id += 1 + + self.unet.set_attn_processor(attn_procs) + + # convert IP-Adapter Image Projection layers to diffusers + image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) + + self.unet.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) + self.unet.config.encoder_hid_dim_type = "ip_image_proj" + + def set_ip_adapter_scale(self, scale): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): + if isinstance(attn_processor, (LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0)): + attn_processor.scale = scale + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + 
torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. 
For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated image embeddings. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. 
+ """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if image_embeds is not None: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to( + device=device, dtype=prompt_embeds.dtype + ) + negative_image_embeds = torch.zeros_like(image_embeds) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, 
nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py index 785f1ee55e..cf0c66bb50 100644 --- a/examples/community/pipeline_animatediff_controlnet.py +++ b/examples/community/pipeline_animatediff_controlnet.py @@ -14,7 +14,7 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """ ... custom_pipeline="pipeline_animatediff_controlnet", ... ).to(device="cuda", dtype=torch.float16) >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained( - ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1 + ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear", ... ) >>> pipe.enable_vae_slicing() @@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """ ... height=768, ... conditioning_frames=conditioning_frames, ... num_inference_steps=12, - ... ).frames[0] + ... ) >>> from diffusers.utils import export_to_gif >>> export_to_gif(result.frames[0], "result.gif") @@ -151,7 +151,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, motion_adapter: MotionAdapter, - controlnet: Union[ControlNetModel, MultiControlNetModel], + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: Union[ DDIMScheduler, PNDMScheduler, @@ -166,6 +166,9 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix super().__init__() unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -488,6 +491,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix prompt, height, width, + num_frames, callback_steps, negative_prompt=None, prompt_embeds=None, @@ -557,31 +561,21 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): - if isinstance(image, list): - for image_ in image: - self.check_image(image_, prompt, prompt_embeds) - else: - self.check_image(image, prompt, prompt_embeds) + if not isinstance(image, list): + raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}") + if len(image) != num_frames: + raise ValueError(f"Excepted image to have length {num_frames} but got {len(image)=}") elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): - if not isinstance(image, list): - raise TypeError("For multiple controlnets: `image` must be type `list`") - - # When `image` is a nested list: - # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) - elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif len(image) != len(self.controlnet.nets): - raise ValueError( - f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." 
- ) - - for control_ in image: - for image_ in control_: - self.check_image(image_, prompt, prompt_embeds) + if not isinstance(image, list) or not isinstance(image[0], list): + raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(image)=}") + if len(image[0]) != num_frames: + raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}") + if any(len(img) != len(image[0]) for img in image): + raise ValueError("All conditioning frame batches for multicontrolnet must be same size") else: assert False @@ -913,6 +907,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix prompt=prompt, height=height, width=width, + num_frames=num_frames, callback_steps=callback_steps, negative_prompt=negative_prompt, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -1000,9 +995,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) - cond_prepared_frames.append(prepared_frame) - conditioning_frames = cond_prepared_frames else: assert False diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index a63654bc99..44a58fa2a8 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -51,7 +51,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import resolve_interpolation_mode +from diffusers.training_utils import cast_training_params, resolve_interpolation_mode from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -860,10 +860,8 @@ def main(args): # Make sure the trainable params are in float32. if args.mixed_precision == "fp16": - for param in unet.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(unet, dtype=torch.float32) # Also move the alpha and sigma noise schedules to accelerator.device. 
alpha_schedule = alpha_schedule.to(accelerator.device) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 2d2629b2fd..3724e3d140 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -35,7 +35,7 @@ from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from packaging import version from peft import LoraConfig -from peft.utils import get_peft_model_state_dict +from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset @@ -54,7 +54,13 @@ from diffusers import ( ) from diffusers.loaders import LoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, + is_wandb_available, +) from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -892,10 +898,33 @@ def main(args): raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) - LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ - ) + + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + if args.train_text_encoder: + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [unet_] + if args.train_text_encoder: + models.append(text_encoder_) + + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -910,6 +939,15 @@ def main(args): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + # Make sure the trainable params are in float32. 
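`cast_training_params`, imported from `diffusers.training_utils` in the hunks above, replaces the manual `requires_grad` loop: it upcasts only the trainable (LoRA) parameters to the requested dtype and leaves frozen weights untouched. A minimal sketch with a toy fp16 module standing in for the LoRA-augmented UNet (an illustrative assumption, not code from this diff):

```python
import torch
from diffusers.training_utils import cast_training_params

# toy stand-in for a LoRA-augmented model kept in fp16
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).half()
model[0].requires_grad_(False)  # frozen "base" weights stay in fp16
model[1].requires_grad_(True)   # trainable "adapter" weights are upcast

cast_training_params(model, dtype=torch.float32)

print(model[0].weight.dtype)  # torch.float16
print(model[1].weight.dtype)  # torch.float32
```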
+ if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.append(text_encoder) + + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 8f92e3b442..a995eb3043 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -34,7 +34,7 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from packaging import version -from peft import LoraConfig +from peft import LoraConfig, set_peft_model_state_dict from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose @@ -53,8 +53,13 @@ from diffusers import ( ) from diffusers.loaders import LoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr -from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, + is_wandb_available, +) from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -997,17 +1002,6 @@ def main(args): text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) - # Make sure the trainable params are in float32. - if args.mixed_precision == "fp16": - models = [unet] - if args.train_text_encoder: - models.extend([text_encoder_one, text_encoder_two]) - for model in models: - for param in model.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) - def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model @@ -1064,17 +1058,36 @@ def main(args): raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) - text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ - ) + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) - text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." 
in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ - ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_ + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [unet_] + if args.train_text_encoder: + models.extend([text_encoder_one_, text_encoder_two_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -1089,6 +1102,15 @@ def main(args): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: @@ -1506,6 +1528,7 @@ def main(args): else unet_lora_parameters ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() lr_scheduler.step() optimizer.zero_grad() diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 78cb7bc2f9..2af858cfd0 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -49,6 +49,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, deprecate, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -489,6 +490,11 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -845,7 +851,7 @@ def main(): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual and compute loss - model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). 
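The `unet(..., return_dict=False)[0]` change above recurs throughout the training scripts in this diff: instead of building the output dataclass and reading its `.sample` attribute, the forward call returns a plain tuple that is indexed directly. Below is a minimal, self-contained sketch of the two equivalent call styles; the tiny UNet configuration is an illustrative assumption, not taken from this diff.

```python
import torch
from diffusers import UNet2DConditionModel

# toy UNet configuration (assumed for illustration only)
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)

with torch.no_grad():
    # previous style: construct the output dataclass and read `.sample`
    pred_a = unet(sample, timestep, encoder_hidden_states).sample
    # style used in this diff: request a plain tuple and index it
    pred_b = unet(sample, timestep, encoder_hidden_states, return_dict=False)[0]

print(torch.equal(pred_a, pred_b))  # True
```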
@@ -919,9 +925,9 @@ def main(): # The models need unwrapping because for compatibility in distributed training mode. pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), - vae=accelerator.unwrap_model(vae), + unet=unwrap_model(unet), + text_encoder=unwrap_model(text_encoder), + vae=unwrap_model(vae), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -965,14 +971,14 @@ def main(): # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) if args.use_ema: ema_unet.copy_to(unet.parameters()) pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), - vae=accelerator.unwrap_model(vae), + text_encoder=unwrap_model(text_encoder), + vae=unwrap_model(vae), unet=unet, revision=args.revision, variant=args.variant, diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index de59cb1f0b..cab16a6333 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -52,6 +52,7 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instru from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -531,6 +532,11 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -1044,8 +1050,12 @@ def main(): added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} model_pred = unet( - concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ).sample + concatenated_noisy_latents, + timesteps, + encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). @@ -1099,7 +1109,7 @@ def main(): progress_bar.set_postfix(**logs) ### BEGIN: Perform validation every `validation_epochs` steps - if global_step % args.validation_steps == 0 or global_step == 1: + if global_step % args.validation_steps == 0: if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" @@ -1115,7 +1125,7 @@ def main(): # The models need unwrapping because for compatibility in distributed training mode. 
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), + unet=unwrap_model(unet), text_encoder=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer=tokenizer_1, @@ -1177,7 +1187,7 @@ def main(): # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) if args.use_ema: ema_unet.copy_to(unet.parameters()) diff --git a/examples/research_projects/realfill/requirements.txt b/examples/research_projects/realfill/requirements.txt index 5d69d84563..f6abdc6e7e 100644 --- a/examples/research_projects/realfill/requirements.txt +++ b/examples/research_projects/realfill/requirements.txt @@ -6,4 +6,4 @@ torch==2.0.1 torchvision>=0.16 ftfy==6.1.1 tensorboard==2.14.0 -Jinja2==3.1.2 +Jinja2==3.1.3 diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index e06fc227af..645f1f04e1 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -50,6 +50,7 @@ from diffusers import ( from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module MAX_SEQ_LENGTH = 77 @@ -926,6 +927,11 @@ def main(args): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -935,9 +941,9 @@ def main(args): " doing mixed precision training, copy of the weights should still be float32." ) - if accelerator.unwrap_model(t2iadapter).dtype != torch.float32: + if unwrap_model(t2iadapter).dtype != torch.float32: raise ValueError( - f"Controlnet loaded as datatype {accelerator.unwrap_model(t2iadapter).dtype}. {low_precision_error_string}" + f"Controlnet loaded as datatype {unwrap_model(t2iadapter).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, @@ -1198,7 +1204,8 @@ def main(args): encoder_hidden_states=batch["prompt_ids"], added_cond_kwargs=batch["unet_added_conditions"], down_block_additional_residuals=down_block_additional_residuals, - ).sample + return_dict=False, + )[0] # Denoise the latents denoised_latents = model_pred * (-sigmas) + noisy_latents @@ -1279,7 +1286,7 @@ def main(args): # Create the pipeline using using the trained modules and save it. 
accelerator.wait_for_everyone() if accelerator.is_main_process: - t2iadapter = accelerator.unwrap_model(t2iadapter) + t2iadapter = unwrap_model(t2iadapter) t2iadapter.save_pretrained(args.output_dir) if args.push_to_hub: diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 27dedc8f7f..906884eea7 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -43,9 +43,10 @@ from transformers import CLIPTextModel, CLIPTokenizer import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr +from diffusers.training_utils import cast_training_params, compute_snr from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -466,10 +467,8 @@ def main(): # Add adapter and make sure the trainable params are in float32. unet.add_adapter(unet_lora_config) if args.mixed_precision == "fp16": - for param in unet.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(unet, dtype=torch.float32) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -598,6 +597,11 @@ def main(): ] ) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] examples["pixel_values"] = [train_transforms(image) for image in images] @@ -731,7 +735,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] # Get the target for loss depending on the prediction type if args.prediction_type is not None: @@ -746,7 +750,7 @@ def main(): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual and compute loss - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] if args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -811,7 +815,7 @@ def main(): save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) - unwrapped_unet = accelerator.unwrap_model(unet) + unwrapped_unet = unwrap_model(unet) unet_lora_state_dict = convert_state_dict_to_diffusers( get_peft_model_state_dict(unwrapped_unet) ) @@ -839,7 +843,7 @@ def main(): # create pipeline pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), + unet=unwrap_model(unet), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -880,7 +884,7 @@ def main(): if accelerator.is_main_process: unet = unet.to(torch.float32) - 
unwrapped_unet = accelerator.unwrap_model(unet) + unwrapped_unet = unwrap_model(unet) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet)) StableDiffusionPipeline.save_lora_weights( save_directory=args.output_dir, diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 606a88f55b..6b13f75ead 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -51,9 +51,10 @@ from diffusers import ( ) from diffusers.loaders import LoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr +from diffusers.training_utils import cast_training_params, compute_snr from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), - output_hidden_states=True, + text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds[-1][-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) @@ -634,11 +634,13 @@ def main(args): models = [unet] if args.train_text_encoder: models.extend([text_encoder_one, text_encoder_two]) - for model in models: - for param in model.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -650,13 +652,13 @@ def main(args): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) @@ -681,11 +683,11 @@ def main(args): while len(models) > 0: model = models.pop() - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_ = model - elif 
isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -834,6 +836,9 @@ def main(args): for image in images: original_sizes.append((image.height, image.width)) image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) if args.center_crop: y1 = max(0, int(round((image.height - args.resolution) / 2.0))) x1 = max(0, int(round((image.width - args.resolution) / 2.0))) @@ -841,10 +846,6 @@ def main(args): else: y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - # flip - x1 = image.width - x1 - image = train_flip(image) crop_top_left = (y1, x1) crop_top_lefts.append(crop_top_left) image = train_transforms(image) @@ -1034,8 +1035,12 @@ def main(args): ) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) model_pred = unet( - noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions - ).sample + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] # Get the target for loss depending on the prediction type if args.prediction_type is not None: @@ -1128,9 +1133,9 @@ def main(args): pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, - text_encoder=accelerator.unwrap_model(text_encoder_one), - text_encoder_2=accelerator.unwrap_model(text_encoder_two), - unet=accelerator.unwrap_model(unet), + text_encoder=unwrap_model(text_encoder_one), + text_encoder_2=unwrap_model(text_encoder_two), + unet=unwrap_model(unet), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -1169,12 +1174,12 @@ def main(args): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: - text_encoder_one = accelerator.unwrap_model(text_encoder_one) - text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_two = unwrap_model(text_encoder_two) text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one)) text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two)) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 0bb57b1f31..5ec27c4c49 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -44,16 +44,12 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import diffusers -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - StableDiffusionXLPipeline, - UNet2DConditionModel, -) +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, compute_snr from diffusers.utils import 
check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca prompt_embeds = text_encoder( text_input_ids.to(text_encoder.device), output_hidden_states=True, + return_dict=False, ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds[-1][-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) @@ -842,6 +839,9 @@ def main(args): for image in images: original_sizes.append((image.height, image.width)) image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) if args.center_crop: y1 = max(0, int(round((image.height - args.resolution) / 2.0))) x1 = max(0, int(round((image.width - args.resolution) / 2.0))) @@ -849,10 +849,6 @@ else: y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - # flip - x1 = image.width - x1 - image = train_flip(image) crop_top_left = (y1, x1) crop_top_lefts.append(crop_top_left) image = train_transforms(image) @@ -955,6 +951,12 @@ def main(args): if accelerator.is_main_process: accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args)) + # Function for unwrapping if torch.compile() was used in accelerate. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1054,8 +1056,12 @@ def main(args): pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) model_pred = unet( - noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions - ).sample + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] # Get the target for loss depending on the prediction type if args.prediction_type is not None: @@ -1206,7 +1212,7 @@ def main(args): accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) if args.use_ema: ema_unet.copy_to(unet.parameters()) diff --git a/scripts/conversion_ldm_uncond.py b/scripts/conversion_ldm_uncond.py index d2ebb3934b..8c22cc1ce8 100644 --- a/scripts/conversion_ldm_uncond.py +++ b/scripts/conversion_ldm_uncond.py @@ -1,13 +1,13 @@ import argparse -import OmegaConf import torch +import yaml from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel def convert_ldm_original(checkpoint_path, config_path, output_path): - config = OmegaConf.load(config_path) + config = yaml.safe_load(config_path) state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] keys = list(state_dict.keys()) @@ -25,8 +25,8 @@ def convert_ldm_original(checkpoint_path, config_path, output_path): if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = state_dict[key] - vqvae_init_args = config.model.params.first_stage_config.params - unet_init_args = config.model.params.unet_config.params + vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"] + unet_init_args = config["model"]["params"]["unet_config"]["params"] vqvae = VQModel(**vqvae_init_args).eval() vqvae.load_state_dict(first_stage_dict) @@ -35,10 +35,10 @@ def convert_ldm_original(checkpoint_path, config_path, output_path): unet.load_state_dict(unet_state_dict) noise_scheduler = DDIMScheduler( - timesteps=config.model.params.timesteps, + timesteps=config["model"]["params"]["timesteps"], beta_schedule="scaled_linear", - beta_start=config.model.params.linear_start, - beta_end=config.model.params.linear_end, + beta_start=config["model"]["params"]["linear_start"], + beta_end=config["model"]["params"]["linear_end"], clip_sample=False, ) diff --git a/scripts/convert_gligen_to_diffusers.py b/scripts/convert_gligen_to_diffusers.py index 816e4c112e..30d789b606 100644 --- a/scripts/convert_gligen_to_diffusers.py +++ b/scripts/convert_gligen_to_diffusers.py @@ -2,6 +2,7 @@ import argparse import re import torch +import yaml from transformers import ( CLIPProcessor, CLIPTextModel, @@ -28,8 +29,6 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( textenc_conversion_map, textenc_pattern, ) -from diffusers.utils import is_omegaconf_available -from diffusers.utils.import_utils import BACKENDS_MAPPING def convert_open_clip_checkpoint(checkpoint): @@ -370,52 +369,52 @@ def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=Fa def create_vae_config(original_config, image_size: int): - vae_params = original_config.autoencoder.params.ddconfig - _ = original_config.autoencoder.params.embed_dim + vae_params = original_config["autoencoder"]["params"]["ddconfig"] + _ = original_config["autoencoder"]["params"]["embed_dim"] - block_out_channels = [vae_params.ch * mult for 
mult in vae_params.ch_mult] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) config = { "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], } return config def create_unet_config(original_config, image_size: int, attention_type): - unet_params = original_config.model.params - vae_params = original_config.autoencoder.params.ddconfig + unet_params = original_config["model"]["params"] + vae_params = original_config["autoencoder"]["params"]["ddconfig"] - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - head_dim = unet_params.num_heads if "num_heads" in unet_params else None + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: if head_dim is None: @@ -423,11 +422,11 @@ def create_unet_config(original_config, image_size: int, attention_type): config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, + "in_channels": unet_params["in_channels"], "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, - "cross_attention_dim": unet_params.context_dim, + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": unet_params["context_dim"], "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "attention_type": attention_type, @@ -445,11 +444,6 @@ def convert_gligen_to_diffusers( num_in_channels: int = None, device: str = None, ): - if not is_omegaconf_available(): - raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) - - from omegaconf import OmegaConf - if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path, map_location=device) 
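The hunks above and below all follow the same migration pattern: `OmegaConf.load` plus attribute-style lookups (for example `config.model.params.timesteps`) are replaced by `yaml.safe_load` plus plain dict indexing. A minimal sketch of that pattern follows; the config path and key names are illustrative only, and note that `yaml.safe_load` parses YAML text or an open stream rather than a filesystem path, so a path argument should be opened (or read) before loading.

```python
import yaml

# Hypothetical config path, for illustration only.
config_path = "original_config.yaml"

# yaml.safe_load expects YAML text or an open stream, not a path,
# so open the file before parsing it.
with open(config_path) as f:
    original_config = yaml.safe_load(f)

# Nested values are read with dict indexing instead of attribute access.
num_train_timesteps = original_config["model"]["params"]["timesteps"]
beta_start = original_config["model"]["params"]["linear_start"]
beta_end = original_config["model"]["params"]["linear_end"]

# Key-membership checks still work on the plain dict, so config-dependent
# defaults can be kept as-is.
unet_params = original_config["model"]["params"]["unet_config"]["params"]
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
```

The same indexing style carries through the UNet, VAE, and vocoder sub-configs converted in the scripts that follow.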
@@ -461,14 +455,14 @@ def convert_gligen_to_diffusers( else: print("global_step key not found in model") - original_config = OmegaConf.load(original_config_file) + original_config = yaml.safe_load(original_config_file) if num_in_channels is not None: original_config["model"]["params"]["in_channels"] = num_in_channels - num_train_timesteps = original_config.diffusion.params.timesteps - beta_start = original_config.diffusion.params.linear_start - beta_end = original_config.diffusion.params.linear_end + num_train_timesteps = original_config["diffusion"]["params"]["timesteps"] + beta_start = original_config["diffusion"]["params"]["linear_start"] + beta_end = original_config["diffusion"]["params"]["linear_end"] scheduler = DDIMScheduler( beta_end=beta_end, diff --git a/scripts/convert_if.py b/scripts/convert_if.py index 66d7f694c8..c4588f4b25 100644 --- a/scripts/convert_if.py +++ b/scripts/convert_if.py @@ -4,6 +4,7 @@ import os import numpy as np import torch +import yaml from torch.nn import functional as F from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer @@ -11,14 +12,6 @@ from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker -try: - from omegaconf import OmegaConf -except ImportError: - raise ImportError( - "OmegaConf is required to convert the IF checkpoints. Please install it with `pip install" " OmegaConf`." - ) - - def parse_args(): parser = argparse.ArgumentParser() @@ -143,8 +136,8 @@ def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safet def get_stage_1_unet(unet_config, unet_checkpoint_path): - original_unet_config = OmegaConf.load(unet_config) - original_unet_config = original_unet_config.params + original_unet_config = yaml.safe_load(unet_config) + original_unet_config = original_unet_config["params"] unet_diffusers_config = create_unet_diffusers_config(original_unet_config) @@ -215,11 +208,11 @@ def convert_safety_checker(p_head_path, w_head_path): def create_unet_diffusers_config(original_unet_config, class_embed_type=None): - attention_resolutions = parse_list(original_unet_config.attention_resolutions) - attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions] + attention_resolutions = parse_list(original_unet_config["attention_resolutions"]) + attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions] - channel_mult = parse_list(original_unet_config.channel_mult) - block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult] + channel_mult = parse_list(original_unet_config["channel_mult"]) + block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult] down_block_types = [] resolution = 1 @@ -227,7 +220,7 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None): for i in range(len(block_out_channels)): if resolution in attention_resolutions: block_type = "SimpleCrossAttnDownBlock2D" - elif original_unet_config.resblock_updown: + elif original_unet_config["resblock_updown"]: block_type = "ResnetDownsampleBlock2D" else: block_type = "DownBlock2D" @@ -241,17 +234,17 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None): for i in range(len(block_out_channels)): if resolution in attention_resolutions: block_type = "SimpleCrossAttnUpBlock2D" - elif original_unet_config.resblock_updown: 
+ elif original_unet_config["resblock_updown"]: block_type = "ResnetUpsampleBlock2D" else: block_type = "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - head_dim = original_unet_config.num_head_channels + head_dim = original_unet_config["num_head_channels"] use_linear_projection = ( - original_unet_config.use_linear_in_transformer + original_unet_config["use_linear_in_transformer"] if "use_linear_in_transformer" in original_unet_config else False ) @@ -264,27 +257,27 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None): if class_embed_type is None: if "num_classes" in original_unet_config: - if original_unet_config.num_classes == "sequential": + if original_unet_config["num_classes"] == "sequential": class_embed_type = "projection" assert "adm_in_channels" in original_unet_config - projection_class_embeddings_input_dim = original_unet_config.adm_in_channels + projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"] else: raise NotImplementedError( - f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}" + f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}" ) config = { - "sample_size": original_unet_config.image_size, - "in_channels": original_unet_config.in_channels, + "sample_size": original_unet_config["image_size"], + "in_channels": original_unet_config["in_channels"], "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": original_unet_config.num_res_blocks, - "cross_attention_dim": original_unet_config.encoder_channels, + "layers_per_block": original_unet_config["num_res_blocks"], + "cross_attention_dim": original_unet_config["encoder_channels"], "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - "out_channels": original_unet_config.out_channels, + "out_channels": original_unet_config["out_channels"], "up_block_types": tuple(up_block_types), "upcast_attention": False, # TODO: guessing "cross_attention_norm": "group_norm", @@ -293,11 +286,11 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None): "act_fn": "gelu", } - if original_unet_config.use_scale_shift_norm: + if original_unet_config["use_scale_shift_norm"]: config["resnet_time_scale_shift"] = "scale_shift" if "encoder_dim" in original_unet_config: - config["encoder_hid_dim"] = original_unet_config.encoder_dim + config["encoder_hid_dim"] = original_unet_config["encoder_dim"] return config @@ -725,15 +718,15 @@ def parse_list(value): def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None): orig_path = unet_checkpoint_path - original_unet_config = OmegaConf.load(os.path.join(orig_path, "config.yml")) - original_unet_config = original_unet_config.params + original_unet_config = yaml.safe_load(os.path.join(orig_path, "config.yml")) + original_unet_config = original_unet_config["params"] unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config) - unet_diffusers_config["time_embedding_dim"] = original_unet_config.model_channels * int( - original_unet_config.channel_mult.split(",")[-1] + unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int( + original_unet_config["channel_mult"].split(",")[-1] ) - if original_unet_config.encoder_dim != original_unet_config.encoder_channels: - 
unet_diffusers_config["encoder_hid_dim"] = original_unet_config.encoder_dim + if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]: + unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"] unet_diffusers_config["class_embed_type"] = "timestep" unet_diffusers_config["addition_embed_type"] = "text" @@ -742,16 +735,16 @@ def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_siz unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071 unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071 unet_diffusers_config["only_cross_attention"] = ( - bool(original_unet_config.disable_self_attentions) + bool(original_unet_config["disable_self_attentions"]) if ( "disable_self_attentions" in original_unet_config - and isinstance(original_unet_config.disable_self_attentions, int) + and isinstance(original_unet_config["disable_self_attentions"], int) ) else True ) if sample_size is None: - unet_diffusers_config["sample_size"] = original_unet_config.image_size + unet_diffusers_config["sample_size"] = original_unet_config["image_size"] else: # The second upscaler unet's sample size is incorrectly specified # in the config and is instead hardcoded in source @@ -783,11 +776,11 @@ def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_siz def superres_create_unet_diffusers_config(original_unet_config): - attention_resolutions = parse_list(original_unet_config.attention_resolutions) - attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions] + attention_resolutions = parse_list(original_unet_config["attention_resolutions"]) + attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions] - channel_mult = parse_list(original_unet_config.channel_mult) - block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult] + channel_mult = parse_list(original_unet_config["channel_mult"]) + block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult] down_block_types = [] resolution = 1 @@ -795,7 +788,7 @@ def superres_create_unet_diffusers_config(original_unet_config): for i in range(len(block_out_channels)): if resolution in attention_resolutions: block_type = "SimpleCrossAttnDownBlock2D" - elif original_unet_config.resblock_updown: + elif original_unet_config["resblock_updown"]: block_type = "ResnetDownsampleBlock2D" else: block_type = "DownBlock2D" @@ -809,16 +802,16 @@ def superres_create_unet_diffusers_config(original_unet_config): for i in range(len(block_out_channels)): if resolution in attention_resolutions: block_type = "SimpleCrossAttnUpBlock2D" - elif original_unet_config.resblock_updown: + elif original_unet_config["resblock_updown"]: block_type = "ResnetUpsampleBlock2D" else: block_type = "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - head_dim = original_unet_config.num_head_channels + head_dim = original_unet_config["num_head_channels"] use_linear_projection = ( - original_unet_config.use_linear_in_transformer + original_unet_config["use_linear_in_transformer"] if "use_linear_in_transformer" in original_unet_config else False ) @@ -831,26 +824,26 @@ def superres_create_unet_diffusers_config(original_unet_config): projection_class_embeddings_input_dim = None if "num_classes" in original_unet_config: - if original_unet_config.num_classes == "sequential": + if original_unet_config["num_classes"] == "sequential": class_embed_type = 
"projection" assert "adm_in_channels" in original_unet_config - projection_class_embeddings_input_dim = original_unet_config.adm_in_channels + projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"] else: raise NotImplementedError( - f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}" + f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}" ) config = { - "in_channels": original_unet_config.in_channels, + "in_channels": original_unet_config["in_channels"], "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": tuple(original_unet_config.num_res_blocks), - "cross_attention_dim": original_unet_config.encoder_channels, + "layers_per_block": tuple(original_unet_config["num_res_blocks"]), + "cross_attention_dim": original_unet_config["encoder_channels"], "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - "out_channels": original_unet_config.out_channels, + "out_channels": original_unet_config["out_channels"], "up_block_types": tuple(up_block_types), "upcast_attention": False, # TODO: guessing "cross_attention_norm": "group_norm", @@ -858,7 +851,7 @@ def superres_create_unet_diffusers_config(original_unet_config): "act_fn": "gelu", } - if original_unet_config.use_scale_shift_norm: + if original_unet_config["use_scale_shift_norm"]: config["resnet_time_scale_shift"] = "scale_shift" return config diff --git a/scripts/convert_original_audioldm2_to_diffusers.py b/scripts/convert_original_audioldm2_to_diffusers.py index f0b22cb4b4..8c9878526a 100644 --- a/scripts/convert_original_audioldm2_to_diffusers.py +++ b/scripts/convert_original_audioldm2_to_diffusers.py @@ -19,6 +19,7 @@ import re from typing import List, Union import torch +import yaml from transformers import ( AutoFeatureExtractor, AutoTokenizer, @@ -45,7 +46,7 @@ from diffusers import ( LMSDiscreteScheduler, PNDMScheduler, ) -from diffusers.utils import is_omegaconf_available, is_safetensors_available +from diffusers.utils import is_safetensors_available from diffusers.utils.import_utils import BACKENDS_MAPPING @@ -212,41 +213,41 @@ def create_unet_diffusers_config(original_config, image_size: int): """ Creates a UNet config for diffusers based on the config of the original AudioLDM2 model. 
""" - unet_params = original_config.model.params.unet_config.params - vae_params = original_config.model.params.first_stage_config.params.ddconfig + unet_params = original_config["model"]["params"]["unet_config"]["params"] + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - cross_attention_dim = list(unet_params.context_dim) if "context_dim" in unet_params else block_out_channels + cross_attention_dim = list(unet_params["context_dim"]) if "context_dim" in unet_params else block_out_channels if len(cross_attention_dim) > 1: # require two or more cross-attention layers per-block, each of different dimension cross_attention_dim = [cross_attention_dim for _ in range(len(block_out_channels))] config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, - "out_channels": unet_params.out_channels, + "in_channels": unet_params["in_channels"], + "out_channels": unet_params["out_channels"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, - "transformer_layers_per_block": unet_params.transformer_depth, + "layers_per_block": unet_params["num_res_blocks"], + "transformer_layers_per_block": unet_params["transformer_depth"], "cross_attention_dim": tuple(cross_attention_dim), } @@ -259,24 +260,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int): Creates a VAE config for diffusers based on the config of the original AudioLDM2 model. Compared to the original Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE. 
""" - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215 + scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215 config = { "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], "scaling_factor": float(scaling_factor), } return config @@ -285,9 +286,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int): # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular def create_diffusers_schedular(original_config): schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], beta_schedule="scaled_linear", ) return schedular @@ -692,17 +693,17 @@ def create_transformers_vocoder_config(original_config): """ Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model. 
""" - vocoder_params = original_config.model.params.vocoder_config.params + vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"] config = { - "model_in_dim": vocoder_params.num_mels, - "sampling_rate": vocoder_params.sampling_rate, - "upsample_initial_channel": vocoder_params.upsample_initial_channel, - "upsample_rates": list(vocoder_params.upsample_rates), - "upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes), - "resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes), + "model_in_dim": vocoder_params["num_mels"], + "sampling_rate": vocoder_params["sampling_rate"], + "upsample_initial_channel": vocoder_params["upsample_initial_channel"], + "upsample_rates": list(vocoder_params["upsample_rates"]), + "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]), + "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]), "resblock_dilation_sizes": [ - list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes + list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"] ], "normalize_before": False, } @@ -876,11 +877,6 @@ def load_pipeline_from_original_AudioLDM2_ckpt( return: An AudioLDM2Pipeline object representing the passed-in `.ckpt`/`.safetensors` file. """ - if not is_omegaconf_available(): - raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) - - from omegaconf import OmegaConf - if from_safetensors: if not is_safetensors_available(): raise ValueError(BACKENDS_MAPPING["safetensors"][1]) @@ -903,9 +899,8 @@ def load_pipeline_from_original_AudioLDM2_ckpt( if original_config_file is None: original_config = DEFAULT_CONFIG - original_config = OmegaConf.create(original_config) else: - original_config = OmegaConf.load(original_config_file) + original_config = yaml.safe_load(original_config_file) if image_size is not None: original_config["model"]["params"]["unet_config"]["params"]["image_size"] = image_size @@ -926,9 +921,9 @@ def load_pipeline_from_original_AudioLDM2_ckpt( if prediction_type is None: prediction_type = "epsilon" - num_train_timesteps = original_config.model.params.timesteps - beta_start = original_config.model.params.linear_start - beta_end = original_config.model.params.linear_end + num_train_timesteps = original_config["model"]["params"]["timesteps"] + beta_start = original_config["model"]["params"]["linear_start"] + beta_end = original_config["model"]["params"]["linear_end"] scheduler = DDIMScheduler( beta_end=beta_end, @@ -1026,9 +1021,9 @@ def load_pipeline_from_original_AudioLDM2_ckpt( # Convert the GPT2 encoder model: AudioLDM2 uses the same configuration as the original GPT2 base model gpt2_config = GPT2Config.from_pretrained("gpt2") gpt2_model = GPT2Model(gpt2_config) - gpt2_model.config.max_new_tokens = ( - original_config.model.params.cond_stage_config.crossattn_audiomae_generated.params.sequence_gen_length - ) + gpt2_model.config.max_new_tokens = original_config["model"]["params"]["cond_stage_config"][ + "crossattn_audiomae_generated" + ]["params"]["sequence_gen_length"] converted_gpt2_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.model.") gpt2_model.load_state_dict(converted_gpt2_checkpoint) diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py index 940c74e9cd..79e8fcc1af 100644 --- a/scripts/convert_original_audioldm_to_diffusers.py +++ b/scripts/convert_original_audioldm_to_diffusers.py @@ -18,6 +18,7 @@ import argparse import re 
import torch +import yaml from transformers import ( AutoTokenizer, ClapTextConfig, @@ -38,8 +39,6 @@ from diffusers import ( PNDMScheduler, UNet2DConditionModel, ) -from diffusers.utils import is_omegaconf_available -from diffusers.utils.import_utils import BACKENDS_MAPPING # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments @@ -215,45 +214,45 @@ def create_unet_diffusers_config(original_config, image_size: int): """ Creates a UNet config for diffusers based on the config of the original AudioLDM model. """ - unet_params = original_config.model.params.unet_config.params - vae_params = original_config.model.params.first_stage_config.params.ddconfig + unet_params = original_config["model"]["params"]["unet_config"]["params"] + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) cross_attention_dim = ( - unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels + unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels ) class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None projection_class_embeddings_input_dim = ( - unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None + unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None ) - class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None + class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, - "out_channels": unet_params.out_channels, + "in_channels": unet_params["in_channels"], + "out_channels": unet_params["out_channels"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, + "layers_per_block": unet_params["num_res_blocks"], "cross_attention_dim": cross_attention_dim, "class_embed_type": class_embed_type, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, @@ -269,24 +268,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int): Creates a VAE config for diffusers based on the config of the original AudioLDM model. 
Compared to the original Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE. """ - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215 + scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215 config = { "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], "scaling_factor": float(scaling_factor), } return config @@ -295,9 +294,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int): # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular def create_diffusers_schedular(original_config): schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], beta_schedule="scaled_linear", ) return schedular @@ -668,17 +667,17 @@ def create_transformers_vocoder_config(original_config): """ Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model. 
""" - vocoder_params = original_config.model.params.vocoder_config.params + vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"] config = { - "model_in_dim": vocoder_params.num_mels, - "sampling_rate": vocoder_params.sampling_rate, - "upsample_initial_channel": vocoder_params.upsample_initial_channel, - "upsample_rates": list(vocoder_params.upsample_rates), - "upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes), - "resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes), + "model_in_dim": vocoder_params["num_mels"], + "sampling_rate": vocoder_params["sampling_rate"], + "upsample_initial_channel": vocoder_params["upsample_initial_channel"], + "upsample_rates": list(vocoder_params["upsample_rates"]), + "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]), + "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]), "resblock_dilation_sizes": [ - list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes + list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"] ], "normalize_before": False, } @@ -818,11 +817,6 @@ def load_pipeline_from_original_audioldm_ckpt( return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file. """ - if not is_omegaconf_available(): - raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) - - from omegaconf import OmegaConf - if from_safetensors: from safetensors import safe_open @@ -842,9 +836,8 @@ def load_pipeline_from_original_audioldm_ckpt( if original_config_file is None: original_config = DEFAULT_CONFIG - original_config = OmegaConf.create(original_config) else: - original_config = OmegaConf.load(original_config_file) + original_config = yaml.safe_load(original_config_file) if num_in_channels is not None: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels @@ -868,9 +861,9 @@ def load_pipeline_from_original_audioldm_ckpt( if image_size is None: image_size = 512 - num_train_timesteps = original_config.model.params.timesteps - beta_start = original_config.model.params.linear_start - beta_end = original_config.model.params.linear_end + num_train_timesteps = original_config["model"]["params"]["timesteps"] + beta_start = original_config["model"]["params"]["linear_start"] + beta_end = original_config["model"]["params"]["linear_end"] scheduler = DDIMScheduler( beta_end=beta_end, diff --git a/scripts/convert_original_musicldm_to_diffusers.py b/scripts/convert_original_musicldm_to_diffusers.py index bbc2fc96f8..b7da888a06 100644 --- a/scripts/convert_original_musicldm_to_diffusers.py +++ b/scripts/convert_original_musicldm_to_diffusers.py @@ -18,6 +18,7 @@ import argparse import re import torch +import yaml from transformers import ( AutoFeatureExtractor, AutoTokenizer, @@ -39,8 +40,6 @@ from diffusers import ( PNDMScheduler, UNet2DConditionModel, ) -from diffusers.utils import is_omegaconf_available -from diffusers.utils.import_utils import BACKENDS_MAPPING # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments @@ -212,45 +211,45 @@ def create_unet_diffusers_config(original_config, image_size: int): """ Creates a UNet config for diffusers based on the config of the original MusicLDM model. 
""" - unet_params = original_config.model.params.unet_config.params - vae_params = original_config.model.params.first_stage_config.params.ddconfig + unet_params = original_config["model"]["params"]["unet_config"]["params"] + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) cross_attention_dim = ( - unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels + unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels ) class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None projection_class_embeddings_input_dim = ( - unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None + unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None ) - class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None + class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, - "out_channels": unet_params.out_channels, + "in_channels": unet_params["in_channels"], + "out_channels": unet_params["out_channels"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, + "layers_per_block": unet_params["num_res_blocks"], "cross_attention_dim": cross_attention_dim, "class_embed_type": class_embed_type, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, @@ -266,24 +265,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int): Creates a VAE config for diffusers based on the config of the original MusicLDM model. Compared to the original Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE. 
""" - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215 + scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215 config = { "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], "scaling_factor": float(scaling_factor), } return config @@ -292,9 +291,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int): # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular def create_diffusers_schedular(original_config): schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], beta_schedule="scaled_linear", ) return schedular @@ -674,17 +673,17 @@ def create_transformers_vocoder_config(original_config): """ Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model. 
""" - vocoder_params = original_config.model.params.vocoder_config.params + vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"] config = { - "model_in_dim": vocoder_params.num_mels, - "sampling_rate": vocoder_params.sampling_rate, - "upsample_initial_channel": vocoder_params.upsample_initial_channel, - "upsample_rates": list(vocoder_params.upsample_rates), - "upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes), - "resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes), + "model_in_dim": vocoder_params["num_mels"], + "sampling_rate": vocoder_params["sampling_rate"], + "upsample_initial_channel": vocoder_params["upsample_initial_channel"], + "upsample_rates": list(vocoder_params["upsample_rates"]), + "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]), + "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]), "resblock_dilation_sizes": [ - list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes + list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"] ], "normalize_before": False, } @@ -823,12 +822,6 @@ def load_pipeline_from_original_MusicLDM_ckpt( If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. return: An MusicLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file. """ - - if not is_omegaconf_available(): - raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) - - from omegaconf import OmegaConf - if from_safetensors: from safetensors import safe_open @@ -848,9 +841,8 @@ def load_pipeline_from_original_MusicLDM_ckpt( if original_config_file is None: original_config = DEFAULT_CONFIG - original_config = OmegaConf.create(original_config) else: - original_config = OmegaConf.load(original_config_file) + original_config = yaml.safe_load(original_config_file) if num_in_channels is not None: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels @@ -874,9 +866,9 @@ def load_pipeline_from_original_MusicLDM_ckpt( if image_size is None: image_size = 512 - num_train_timesteps = original_config.model.params.timesteps - beta_start = original_config.model.params.linear_start - beta_end = original_config.model.params.linear_end + num_train_timesteps = original_config["model"]["params"]["timesteps"] + beta_start = original_config["model"]["params"]["linear_start"] + beta_end = original_config["model"]["params"]["linear_end"] scheduler = DDIMScheduler( beta_end=beta_end, diff --git a/scripts/convert_vae_pt_to_diffusers.py b/scripts/convert_vae_pt_to_diffusers.py index a8ba48bc00..a4f967c94f 100644 --- a/scripts/convert_vae_pt_to_diffusers.py +++ b/scripts/convert_vae_pt_to_diffusers.py @@ -3,7 +3,7 @@ import io import requests import torch -from omegaconf import OmegaConf +import yaml from diffusers import AutoencoderKL from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( @@ -126,7 +126,7 @@ def vae_pt_to_vae_diffuser( ) io_obj = io.BytesIO(r.content) - original_config = OmegaConf.load(io_obj) + original_config = yaml.safe_load(io_obj) image_size = 512 device = "cuda" if torch.cuda.is_available() else "cpu" if checkpoint_path.endswith("safetensors"): diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py index 58ed2d93d5..7da6b40949 100644 --- a/scripts/convert_vq_diffusion_to_diffusers.py +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -45,51 +45,45 @@ from diffusers import 
Transformer2DModel, VQDiffusionPipeline, VQDiffusionSchedu from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings -try: - from omegaconf import OmegaConf -except ImportError: - raise ImportError( - "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install" - " OmegaConf`." - ) - # vqvae model PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"] def vqvae_model_from_original_config(original_config): - assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers." + assert ( + original_config["target"] in PORTED_VQVAES + ), f"{original_config['target']} has not yet been ported to diffusers." - original_config = original_config.params + original_config = original_config["params"] - original_encoder_config = original_config.encoder_config.params - original_decoder_config = original_config.decoder_config.params + original_encoder_config = original_config["encoder_config"]["params"] + original_decoder_config = original_config["decoder_config"]["params"] - in_channels = original_encoder_config.in_channels - out_channels = original_decoder_config.out_ch + in_channels = original_encoder_config["in_channels"] + out_channels = original_decoder_config["out_ch"] down_block_types = get_down_block_types(original_encoder_config) up_block_types = get_up_block_types(original_decoder_config) - assert original_encoder_config.ch == original_decoder_config.ch - assert original_encoder_config.ch_mult == original_decoder_config.ch_mult + assert original_encoder_config["ch"] == original_decoder_config["ch"] + assert original_encoder_config["ch_mult"] == original_decoder_config["ch_mult"] block_out_channels = tuple( - [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult] + [original_encoder_config["ch"] * a_ch_mult for a_ch_mult in original_encoder_config["ch_mult"]] ) - assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks - layers_per_block = original_encoder_config.num_res_blocks + assert original_encoder_config["num_res_blocks"] == original_decoder_config["num_res_blocks"] + layers_per_block = original_encoder_config["num_res_blocks"] - assert original_encoder_config.z_channels == original_decoder_config.z_channels - latent_channels = original_encoder_config.z_channels + assert original_encoder_config["z_channels"] == original_decoder_config["z_channels"] + latent_channels = original_encoder_config["z_channels"] - num_vq_embeddings = original_config.n_embed + num_vq_embeddings = original_config["n_embed"] # Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion norm_num_groups = 32 - e_dim = original_config.embed_dim + e_dim = original_config["embed_dim"] model = VQModel( in_channels=in_channels, @@ -108,9 +102,9 @@ def vqvae_model_from_original_config(original_config): def get_down_block_types(original_encoder_config): - attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions) - num_resolutions = len(original_encoder_config.ch_mult) - resolution = coerce_resolution(original_encoder_config.resolution) + attn_resolutions = coerce_attn_resolutions(original_encoder_config["attn_resolutions"]) + num_resolutions = len(original_encoder_config["ch_mult"]) + resolution = coerce_resolution(original_encoder_config["resolution"]) curr_res = resolution down_block_types = [] @@ -129,9 +123,9 @@ def 
get_down_block_types(original_encoder_config): def get_up_block_types(original_decoder_config): - attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions) - num_resolutions = len(original_decoder_config.ch_mult) - resolution = coerce_resolution(original_decoder_config.resolution) + attn_resolutions = coerce_attn_resolutions(original_decoder_config["attn_resolutions"]) + num_resolutions = len(original_decoder_config["ch_mult"]) + resolution = coerce_resolution(original_decoder_config["resolution"]) curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution] up_block_types = [] @@ -150,7 +144,7 @@ def get_up_block_types(original_decoder_config): def coerce_attn_resolutions(attn_resolutions): - attn_resolutions = OmegaConf.to_object(attn_resolutions) + attn_resolutions = list(attn_resolutions) attn_resolutions_ = [] for ar in attn_resolutions: if isinstance(ar, (list, tuple)): @@ -161,7 +155,6 @@ def coerce_attn_resolutions(attn_resolutions): def coerce_resolution(resolution): - resolution = OmegaConf.to_object(resolution) if isinstance(resolution, int): resolution = [resolution, resolution] # H, W elif isinstance(resolution, (tuple, list)): @@ -472,18 +465,18 @@ def transformer_model_from_original_config( original_diffusion_config, original_transformer_config, original_content_embedding_config ): assert ( - original_diffusion_config.target in PORTED_DIFFUSIONS - ), f"{original_diffusion_config.target} has not yet been ported to diffusers." + original_diffusion_config["target"] in PORTED_DIFFUSIONS + ), f"{original_diffusion_config['target']} has not yet been ported to diffusers." assert ( - original_transformer_config.target in PORTED_TRANSFORMERS - ), f"{original_transformer_config.target} has not yet been ported to diffusers." + original_transformer_config["target"] in PORTED_TRANSFORMERS + ), f"{original_transformer_config['target']} has not yet been ported to diffusers." assert ( - original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS - ), f"{original_content_embedding_config.target} has not yet been ported to diffusers." + original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS + ), f"{original_content_embedding_config['target']} has not yet been ported to diffusers." - original_diffusion_config = original_diffusion_config.params - original_transformer_config = original_transformer_config.params - original_content_embedding_config = original_content_embedding_config.params + original_diffusion_config = original_diffusion_config["params"] + original_transformer_config = original_transformer_config["params"] + original_content_embedding_config = original_content_embedding_config["params"] inner_dim = original_transformer_config["n_embd"] @@ -689,13 +682,11 @@ def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_fee def read_config_file(filename): # The yaml file contains annotations that certain values should - # loaded as tuples. By default, OmegaConf will panic when reading - # these. Instead, we can manually read the yaml with the FullLoader and then - # construct the OmegaConf object. + # loaded as tuples. 
with open(filename) as f: original_config = yaml.load(f, FullLoader) - return OmegaConf.create(original_config) + return original_config # We take separate arguments for the vqvae because the ITHQ vqvae config file @@ -792,9 +783,9 @@ if __name__ == "__main__": - original_config = read_config_file(args.original_config_file).model + original_config = read_config_file(args.original_config_file)["model"] - diffusion_config = original_config.params.diffusion_config - transformer_config = original_config.params.diffusion_config.params.transformer_config - content_embedding_config = original_config.params.diffusion_config.params.content_emb_config + diffusion_config = original_config["params"]["diffusion_config"] + transformer_config = original_config["params"]["diffusion_config"]["params"]["transformer_config"] + content_embedding_config = original_config["params"]["diffusion_config"]["params"]["content_emb_config"] pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location) @@ -831,7 +822,7 @@ if __name__ == "__main__": # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate # model, so we pull them off the checkpoint before the checkpoint is deleted. - learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf + learnable_classifier_free_sampling_embeddings = diffusion_config["params"]["learnable_cf"] if learnable_classifier_free_sampling_embeddings: learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"] diff --git a/scripts/convert_zero123_to_diffusers.py b/scripts/convert_zero123_to_diffusers.py index bdcb2cd2e1..f016312b8b 100644 --- a/scripts/convert_zero123_to_diffusers.py +++ b/scripts/convert_zero123_to_diffusers.py @@ -14,6 +14,7 @@ $ python convert_zero123_to_diffusers.py \ import argparse import torch +import yaml from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device from pipeline_zero1to3 import CCProjection, Zero1to3StableDiffusionPipeline @@ -38,51 +39,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa Creates a config for the diffusers based on the config of the LDM model. 
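+    `original_config` is expected to be a plain `dict` as returned by `yaml.safe_load`, so nested values
+    are read with bracket access rather than OmegaConf attribute access, and optional keys are looked up
+    with `in` / `dict.get`. For example (illustrative only):
+
+        unet_params = original_config["model"]["params"]["unet_config"]["params"]
+        num_train_timesteps = original_config["model"]["params"].get("timesteps", None) or 1000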
""" if controlnet: - unet_params = original_config.model.params.control_stage_config.params + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] else: - if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: - unet_params = original_config.model.params.unet_config.params + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] else: - unet_params = original_config.model.params.network_config.params + unet_params = original_config["model"]["params"]["network_config"]["params"] - vae_params = original_config.model.params.first_stage_config.params.ddconfig + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - if unet_params.transformer_depth is not None: + if unet_params["transformer_depth"] is not None: transformer_layers_per_block = ( - unet_params.transformer_depth - if isinstance(unet_params.transformer_depth, int) - else list(unet_params.transformer_depth) + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) ) else: transformer_layers_per_block = 1 - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - head_dim = unet_params.num_heads if "num_heads" in unet_params else None + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: - head_dim_mult = unet_params.model_channels // unet_params.num_head_channels - head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] class_embed_type = None addition_embed_type = None @@ -90,13 +94,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa projection_class_embeddings_input_dim = None context_dim = None - if unet_params.context_dim is not None: + if unet_params["context_dim"] is not None: context_dim = ( - unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + 
unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] ) if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": + if unet_params["num_classes"] == "sequential": if context_dim in [2048, 1280]: # SDXL addition_embed_type = "text_time" @@ -104,16 +110,16 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa else: class_embed_type = "projection" assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] else: - raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params["num_classes"]}") config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, + "in_channels": unet_params["in_channels"], "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, + "layers_per_block": unet_params["num_res_blocks"], "cross_attention_dim": context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, @@ -125,9 +131,9 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa } if controlnet: - config["conditioning_channels"] = unet_params.hint_channels + config["conditioning_channels"] = unet_params["hint_channels"] else: - config["out_channels"] = unet_params.out_channels + config["out_channels"] = unet_params["out_channels"] config["up_block_types"] = tuple(up_block_types) return config @@ -487,22 +493,22 @@ def create_vae_diffusers_config(original_config, image_size: int): """ Creates a config for the diffusers based on the config of the LDM model. 
""" - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) config = { "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], } return config @@ -679,18 +685,16 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex del ckpt torch.cuda.empty_cache() - from omegaconf import OmegaConf - - original_config = OmegaConf.load(original_config_file) - original_config.model.params.cond_stage_config.target.split(".")[-1] + original_config = yaml.safe_load(original_config_file) + original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] num_in_channels = 8 original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels prediction_type = "epsilon" image_size = 256 - num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000 - beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 - beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + beta_start = getattr(original_config["model"]["params"], "linear_start", None) or 0.02 + beta_end = getattr(original_config["model"]["params"], "linear_end", None) or 0.085 scheduler = DDIMScheduler( beta_end=beta_end, beta_schedule="scaled_linear", @@ -721,10 +725,10 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex if ( "model" in original_config - and "params" in original_config.model - and "scale_factor" in original_config.model.params + and "params" in original_config["model"] + and "scale_factor" in original_config["model"]["params"] ): - vae_scaling_factor = original_config.model.params.scale_factor + vae_scaling_factor = original_config["model"]["params"]["scale_factor"] else: vae_scaling_factor = 0.18215 # default SD scaling factor diff --git a/setup.py b/setup.py index 177c918d38..bb0f378170 100644 --- a/setup.py +++ b/setup.py @@ -110,7 +110,6 @@ _deps = [ "note_seq", "librosa", "numpy", - "omegaconf", "parameterized", "peft>=0.6.0", "protobuf>=3.20.3,<4", @@ -213,7 +212,6 @@ extras["test"] = deps_list( "invisible-watermark", "k-diffusion", "librosa", - "omegaconf", "parameterized", "pytest", "pytest-timeout", diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 03e8fe7a0a..e92a486bff 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -22,7 +22,6 @@ deps = { 
"note_seq": "note_seq", "librosa": "librosa", "numpy": "numpy", - "omegaconf": "omegaconf", "parameterized": "parameterized", "peft": "peft>=0.6.0", "protobuf": "protobuf>=3.20.3,<4", diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 039b6b910a..f2ac58cf93 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os +from pathlib import Path from typing import Dict, Union import torch @@ -138,7 +138,7 @@ class IPAdapterMixin: logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") image_encoder = CLIPVisionModelWithProjection.from_pretrained( pretrained_model_name_or_path_or_dict, - subfolder=os.path.join(subfolder, "image_encoder"), + subfolder=Path(subfolder, "image_encoder").as_posix(), ).to(self.device, dtype=self.dtype) self.image_encoder = image_encoder self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 424e95f084..922c98b98b 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -14,6 +14,7 @@ import inspect import os from contextlib import nullcontext +from pathlib import Path from typing import Callable, Dict, List, Optional, Union import safetensors @@ -581,7 +582,6 @@ class LoraLoaderMixin: lora_config_kwargs = get_peft_kwargs( rank, network_alphas, text_encoder_lora_state_dict, is_unet=False ) - lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -961,8 +961,9 @@ class LoraLoaderMixin: else: weight_name = LORA_WEIGHT_NAME - save_function(state_dict, os.path.join(save_directory, weight_name)) - logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + save_path = Path(save_directory, weight_name).as_posix() + save_function(state_dict, save_path) + logger.info(f"Model weights saved in {save_path}") def unload_lora_weights(self): """ diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index b0fe790c22..0fb4637dab 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -13,11 +13,13 @@ # limitations under the License. 
import inspect +import math from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch +import torch.fft as fft from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor @@ -36,6 +38,7 @@ from ...schedulers import ( from ...utils import ( USE_PEFT_BACKEND, BaseOutput, + deprecate, logging, replace_example_docstring, scale_lora_layers, @@ -79,6 +82,71 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): return outputs +def _get_freeinit_freq_filter( + shape: Tuple[int, ...], + device: Union[str, torch.dtype], + filter_type: str, + order: float, + spatial_stop_frequency: float, + temporal_stop_frequency: float, +) -> torch.Tensor: + r"""Returns the FreeInit filter based on filter type and other input conditions.""" + + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + + if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: + return mask + + if filter_type == "butterworth": + + def retrieve_mask(x): + return 1 / (1 + (x / spatial_stop_frequency**2) ** order) + elif filter_type == "gaussian": + + def retrieve_mask(x): + return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) + elif filter_type == "ideal": + + def retrieve_mask(x): + return 1 if x <= spatial_stop_frequency * 2 else 0 + else: + raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") + + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ( + ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2 + + (2 * h / H - 1) ** 2 + + (2 * w / W - 1) ** 2 + ) + mask[..., t, h, w] = retrieve_mask(d_square) + + return mask.to(device) + + +def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor: + r"""Noise reinitialization.""" + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + HPF = 1 - LPF + x_freq_low = x_freq * LPF + noise_freq_high = noise_freq * HPF + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + + @dataclass class AnimateDiffPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] @@ -115,6 +183,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["feature_extractor", "image_encoder"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -442,6 +511,58 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + @property + def free_init_enabled(self): + return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None + + def enable_free_init( + self, + num_iters: int = 3, + use_fast_sampling: bool = False, + method: str = "butterworth", + order: int = 4, + spatial_stop_frequency: float = 0.25, + temporal_stop_frequency: float = 0.25, + generator: torch.Generator = None, + ): + 
"""Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. + + This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). + + Args: + num_iters (`int`, *optional*, defaults to `3`): + Number of FreeInit noise re-initialization iterations. + use_fast_sampling (`bool`, *optional*, defaults to `False`): + Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables + the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`. + method (`str`, *optional*, defaults to `butterworth`): + Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the + FreeInit low pass filter. + order (`int`, *optional*, defaults to `4`): + Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour + whereas lower values lead to `gaussian` method behaviour. + spatial_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in + the original implementation. + temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in + the original implementation. + generator (`torch.Generator`, *optional*, defaults to `0.25`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + FreeInit generation deterministic. + """ + self._free_init_num_iters = num_iters + self._free_init_use_fast_sampling = use_fast_sampling + self._free_init_method = method + self._free_init_order = order + self._free_init_spatial_stop_frequency = spatial_stop_frequency + self._free_init_temporal_stop_frequency = temporal_stop_frequency + self._free_init_generator = generator + + def disable_free_init(self): + """Disables the FreeInit mechanism if enabled.""" + self._free_init_num_iters = None + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -539,6 +660,185 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap latents = latents * self.scheduler.init_noise_sigma return latents + def _denoise_loop( + self, + timesteps, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + num_warmup_steps, + prompt_embeds, + negative_prompt_embeds, + latents, + cross_attention_kwargs, + added_cond_kwargs, + extra_step_kwargs, + callback, + callback_steps, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + ): + """Denoising loop for AnimateDiff.""" + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + 
guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + return latents + + def _free_init_loop( + self, + height, + width, + num_frames, + num_channels_latents, + batch_size, + num_videos_per_prompt, + denoise_args, + device, + ): + """Denoising loop for AnimateDiff using FreeInit noise reinitialization technique.""" + + latents = denoise_args.get("latents") + prompt_embeds = denoise_args.get("prompt_embeds") + timesteps = denoise_args.get("timesteps") + num_inference_steps = denoise_args.get("num_inference_steps") + + latent_shape = ( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + free_init_filter_shape = ( + 1, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + free_init_freq_filter = _get_freeinit_freq_filter( + shape=free_init_filter_shape, + device=device, + filter_type=self._free_init_method, + order=self._free_init_order, + spatial_stop_frequency=self._free_init_spatial_stop_frequency, + temporal_stop_frequency=self._free_init_temporal_stop_frequency, + ) + + with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: + for i in range(self._free_init_num_iters): + # For the first FreeInit iteration, the original latent is used without modification. + # Subsequent iterations apply the noise reinitialization technique. 
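+                    # Sketch of the reinitialization steps implemented below:
+                    #   1. re-diffuse the current latents back to the t=T noise level with `scheduler.add_noise`,
+                    #      reusing the noise saved from the first iteration (`initial_noise`)
+                    #   2. sample fresh Gaussian noise `z_rand`
+                    #   3. keep the low-frequency band of the re-diffused latents and the high-frequency band
+                    #      of `z_rand`, i.e. latents = IFFT(FFT(z_T) * LPF + FFT(z_rand) * (1 - LPF))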
+ if i == 0: + initial_noise = latents.detach().clone() + else: + current_diffuse_timestep = ( + self.scheduler.config.num_train_timesteps - 1 + ) # diffuse to t=999 noise level + diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() + z_T = self.scheduler.add_noise( + original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) + ).to(dtype=torch.float32) + z_rand = randn_tensor( + shape=latent_shape, + generator=self._free_init_generator, + device=device, + dtype=torch.float32, + ) + latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) + latents = latents.to(prompt_embeds.dtype) + + # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) + if self._free_init_use_fast_sampling: + current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) + self.scheduler.set_timesteps(current_num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps}) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) + latents = self._denoise_loop(**denoise_args) + + free_init_progress_bar.update() + + return latents + + def _retrieve_video_frames(self, latents, output_type, return_dict): + """Helper function to handle latents to output conversion.""" + if output_type == "latent": + return AnimateDiffPipelineOutput(frames=latents) + + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -559,10 +859,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -603,25 +904,30 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + Examples: Returns: @@ -629,6 +935,23 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -637,9 +960,20 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap # 1. Check inputs. 
Raise error if not correct self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -649,30 +983,26 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_videos_per_prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: @@ -680,12 +1010,13 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap image_embeds, negative_image_embeds = self.encode_image( ip_adapter_image, device, num_videos_per_prompt, output_hidden_state ) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -703,55 +1034,47 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7 Add image embeds for IP-Adapter + + # 7. Add image embeds for IP-Adapter added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - # Denoising loop + # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + denoise_args = { + "timesteps": timesteps, + "num_inference_steps": num_inference_steps, + "do_classifier_free_guidance": self.do_classifier_free_guidance, + "guidance_scale": guidance_scale, + "num_warmup_steps": num_warmup_steps, + "prompt_embeds": prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + "latents": latents, + "cross_attention_kwargs": self.cross_attention_kwargs, + "added_cond_kwargs": added_cond_kwargs, + "extra_step_kwargs": extra_step_kwargs, + "callback": callback, + "callback_steps": callback_steps, + "callback_on_step_end": callback_on_step_end, + "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs, + } - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if output_type == "latent": - return AnimateDiffPipelineOutput(frames=latents) - - # Post-processing - video_tensor = self.decode_latents(latents) - - if output_type == "pt": - video = video_tensor + if self.free_init_enabled: + latents = self._free_init_loop( + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channels_latents, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + denoise_args=denoise_args, + device=device, + ) else: - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + latents = self._denoise_loop(**denoise_args) - # Offload all models + video = self._retrieve_video_frames(latents, output_type, return_dict) + + # 9. Offload all models self.maybe_free_model_hooks() - if not return_dict: - return (video,) - - return AnimateDiffPipelineOutput(frames=video) + return video diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 6bdc281ef8..6cd1658c59 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -603,15 +603,6 @@ class StableDiffusionControlNetPipeline( f" {negative_prompt_embeds.shape}." ) - # `prompt` needs more sophisticated handling when there are multiple - # conditionings. - if isinstance(self.controlnet, MultiControlNetModel): - if isinstance(prompt, list): - logger.warning( - f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" - " prompts. The conditionings will be fixed across the prompts." 
- ) - # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule @@ -633,7 +624,13 @@ class StableDiffusionControlNetPipeline( # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings is not supported at the moment.") + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass `image` as a list of lists, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." @@ -659,7 +656,10 @@ class StableDiffusionControlNetPipeline( ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings is not supported at the moment.") + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): @@ -906,7 +906,9 @@ class StableDiffusionControlNetPipeline( accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. + input to a single ControlNet. When `prompt` is a list and a list of images is passed for a single ControlNet, + each image will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets, + where a list of image lists can be passed to batch for each prompt and each ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. 
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): @@ -1105,6 +1107,11 @@ class StableDiffusionControlNetPipeline( elif isinstance(controlnet, MultiControlNetModel): images = [] + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + for image_ in image: image_ = self.prepare_image( image=image_, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 2083a6391c..6e00134591 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -1087,7 +1087,10 @@ class StableDiffusionControlNetImg2ImgPipeline( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 5aa23252b8..1ba06f811a 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -21,6 +21,7 @@ from typing import Dict, Optional, Union import requests import torch +import yaml from transformers import ( AutoFeatureExtractor, BertTokenizerFast, @@ -50,8 +51,7 @@ from ...schedulers import ( PNDMScheduler, UnCLIPScheduler, ) -from ...utils import is_accelerate_available, is_omegaconf_available, logging -from ...utils.import_utils import BACKENDS_MAPPING +from ...utils import is_accelerate_available, logging from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from ..paint_by_example import PaintByExampleImageEncoder from ..pipeline_utils import DiffusionPipeline @@ -237,51 +237,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa Creates a config for the diffusers based on the config of the LDM model. 
""" if controlnet: - unet_params = original_config.model.params.control_stage_config.params + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] else: - if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: - unet_params = original_config.model.params.unet_config.params + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] else: - unet_params = original_config.model.params.network_config.params + unet_params = original_config["model"]["params"]["network_config"]["params"] - vae_params = original_config.model.params.first_stage_config.params.ddconfig + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - if unet_params.transformer_depth is not None: + if unet_params["transformer_depth"] is not None: transformer_layers_per_block = ( - unet_params.transformer_depth - if isinstance(unet_params.transformer_depth, int) - else list(unet_params.transformer_depth) + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) ) else: transformer_layers_per_block = 1 - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - head_dim = unet_params.num_heads if "num_heads" in unet_params else None + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: - head_dim_mult = unet_params.model_channels // unet_params.num_head_channels - head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] class_embed_type = None addition_embed_type = None @@ -289,13 +292,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa projection_class_embeddings_input_dim = None context_dim = None - if unet_params.context_dim is not None: + if unet_params["context_dim"] is not None: context_dim = ( - unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + 
unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] ) if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": + if unet_params["num_classes"] == "sequential": if context_dim in [2048, 1280]: # SDXL addition_embed_type = "text_time" @@ -303,14 +308,14 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa else: class_embed_type = "projection" assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, + "in_channels": unet_params["in_channels"], "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, + "layers_per_block": unet_params["num_res_blocks"], "cross_attention_dim": context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, @@ -322,15 +327,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa } if "disable_self_attentions" in unet_params: - config["only_cross_attention"] = unet_params.disable_self_attentions + config["only_cross_attention"] = unet_params["disable_self_attentions"] - if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): - config["num_class_embeds"] = unet_params.num_classes + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] if controlnet: - config["conditioning_channels"] = unet_params.hint_channels + config["conditioning_channels"] = unet_params["hint_channels"] else: - config["out_channels"] = unet_params.out_channels + config["out_channels"] = unet_params["out_channels"] config["up_block_types"] = tuple(up_block_types) return config @@ -340,38 +345,38 @@ def create_vae_diffusers_config(original_config, image_size: int): """ Creates a config for the diffusers based on the config of the LDM model. 
""" - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) config = { "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], } return config def create_diffusers_schedular(original_config): schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], beta_schedule="scaled_linear", ) return schedular def create_ldm_bert_config(original_config): - bert_params = original_config.model.params.cond_stage_config.params + bert_params = original_config["model"]["params"]["cond_stage_config"]["params"] config = LDMBertConfig( d_model=bert_params.n_embed, encoder_layers=bert_params.n_layer, @@ -1006,9 +1011,9 @@ def stable_unclip_image_encoder(original_config, local_files_only=False): encoders. """ - image_embedder_config = original_config.model.params.embedder_config + image_embedder_config = original_config["model"]["params"]["embedder_config"] - sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = image_embedder_config["target"] sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] if sd_clip_image_embedder_class == "ClipImageEmbedder": @@ -1047,8 +1052,8 @@ def stable_unclip_image_noising_components( If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. 
""" - noise_aug_config = original_config.model.params.noise_aug_config - noise_aug_class = noise_aug_config.target + noise_aug_config = original_config["model"]["params"]["noise_aug_config"] + noise_aug_class = noise_aug_config["target"] noise_aug_class = noise_aug_class.split(".")[-1] if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": @@ -1245,11 +1250,6 @@ def download_from_original_stable_diffusion_ckpt( if prediction_type == "v-prediction": prediction_type = "v_prediction" - if not is_omegaconf_available(): - raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) - - from omegaconf import OmegaConf - if isinstance(checkpoint_path_or_dict, str): if from_safetensors: from safetensors.torch import load_file as safe_load @@ -1317,19 +1317,22 @@ def download_from_original_stable_diffusion_ckpt( if config_url is not None: original_config_file = BytesIO(requests.get(config_url).content) + else: + with open(original_config_file, "r") as f: + original_config_file = f.read() - original_config = OmegaConf.load(original_config_file) + original_config = yaml.safe_load(original_config_file) # Convert the text model. if ( model_type is None - and "cond_stage_config" in original_config.model.params - and original_config.model.params.cond_stage_config is not None + and "cond_stage_config" in original_config["model"]["params"] + and original_config["model"]["params"]["cond_stage_config"] is not None ): - model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") - elif model_type is None and original_config.model.params.network_config is not None: - if original_config.model.params.network_config.params.context_dim == 2048: + elif model_type is None and original_config["model"]["params"]["network_config"] is not None: + if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048: model_type = "SDXL" else: model_type = "SDXL-Refiner" @@ -1354,7 +1357,7 @@ def download_from_original_stable_diffusion_ckpt( elif num_in_channels is None: num_in_channels = 4 - if "unet_config" in original_config.model.params: + if "unet_config" in original_config["model"]["params"]: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels if ( @@ -1375,13 +1378,16 @@ def download_from_original_stable_diffusion_ckpt( if image_size is None: image_size = 512 - if controlnet is None and "control_stage_config" in original_config.model.params: + if controlnet is None and "control_stage_config" in original_config["model"]["params"]: path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" controlnet = convert_controlnet_checkpoint( checkpoint, original_config, path, image_size, upcast_attention, extract_ema ) - num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + if "timesteps" in original_config["model"]["params"]: + num_train_timesteps = original_config["model"]["params"]["timesteps"] + else: + num_train_timesteps = 1000 if model_type in ["SDXL", "SDXL-Refiner"]: scheduler_dict = { @@ -1400,8 +1406,15 @@ def download_from_original_stable_diffusion_ckpt( scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) scheduler_type = "euler" else: - beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 - beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + if 
"linear_start" in original_config["model"]["params"]: + beta_start = original_config["model"]["params"]["linear_start"] + else: + beta_start = 0.02 + + if "linear_end" in original_config["model"]["params"]: + beta_end = original_config["model"]["params"]["linear_end"] + else: + beta_end = 0.085 scheduler = DDIMScheduler( beta_end=beta_end, beta_schedule="scaled_linear", @@ -1435,7 +1448,7 @@ def download_from_original_stable_diffusion_ckpt( raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") if pipeline_class == StableDiffusionUpscalePipeline: - image_size = original_config.model.params.unet_config.params.image_size + image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] # Convert the UNet2DConditionModel model. unet_config = create_unet_diffusers_config(original_config, image_size=image_size) @@ -1464,10 +1477,10 @@ def download_from_original_stable_diffusion_ckpt( if ( "model" in original_config - and "params" in original_config.model - and "scale_factor" in original_config.model.params + and "params" in original_config["model"] + and "scale_factor" in original_config["model"]["params"] ): - vae_scaling_factor = original_config.model.params.scale_factor + vae_scaling_factor = original_config["model"]["params"]["scale_factor"] else: vae_scaling_factor = 0.18215 # default SD scaling factor @@ -1803,11 +1816,6 @@ def download_controlnet_from_original_ckpt( use_linear_projection: Optional[bool] = None, cross_attention_dim: Optional[bool] = None, ) -> DiffusionPipeline: - if not is_omegaconf_available(): - raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) - - from omegaconf import OmegaConf - if from_safetensors: from safetensors import safe_open @@ -1827,12 +1835,12 @@ def download_controlnet_from_original_ckpt( while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] - original_config = OmegaConf.load(original_config_file) + original_config = yaml.safe_load(original_config_file) if num_in_channels is not None: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - if "control_stage_config" not in original_config.model.params: + if "control_stage_config" not in original_config["model"]["params"]: raise ValueError("`control_stage_config` not present in original config") controlnet = convert_controlnet_checkpoint( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b06363cffd..b653d8e9f7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -858,7 +858,7 @@ class StableDiffusionXLInstructPix2PixPipeline( ) # 4. Preprocess image - image = self.image_processor.preprocess(image).to(device) + image = self.image_processor.preprocess(image, height=height, width=width).to(device) # 5. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index e5360d37c6..56f7269130 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -52,6 +52,9 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): outputs.append(batch_output) + if output_type == "np": + return np.stack(outputs) + return outputs diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 9fb6fad3a2..596e5c4868 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,19 +1,28 @@ import contextlib import copy import random -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union import numpy as np import torch from torchvision import transforms from .models import UNet2DConditionModel -from .utils import deprecate, is_transformers_available +from .utils import ( + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, + deprecate, + is_peft_available, + is_transformers_available, +) if is_transformers_available(): import transformers +if is_peft_available(): + from peft import set_peft_model_state_dict + def set_seed(seed: int): """ @@ -112,6 +121,35 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: return lora_state_dict +def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32): + if not isinstance(model, list): + model = [model] + for m in model: + for param in m.parameters(): + # only upcast trainable parameters into fp32 + if param.requires_grad: + param.data = param.to(dtype) + + +def _set_state_dict_into_text_encoder( + lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module +): + """ + Sets the `lora_state_dict` into `text_encoder` coming from `transformers`. + + Args: + lora_state_dict: The state dictionary to be set. + prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`. + text_encoder: Where the `lora_state_dict` is to be set. 
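+
+    Example (a minimal, illustrative sketch; the `"text_encoder."` prefix here is an assumed value, not required by the helper):
+
+        ```py
+        _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder)
+        ```
+
+    Entries whose keys do not start with `prefix` are ignored.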
+ """ + + text_encoder_state_dict = { + f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix) + } + text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict)) + set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default") + + # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 54fde3e1f7..667f1fe5e2 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,7 +67,6 @@ from .import_utils import ( is_k_diffusion_version, is_librosa_available, is_note_seq_available, - is_omegaconf_available, is_onnx_available, is_peft_available, is_scipy_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b3278af2f6..ac1565023b 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -223,12 +223,6 @@ try: except importlib_metadata.PackageNotFoundError: _wandb_available = False -_omegaconf_available = importlib.util.find_spec("omegaconf") is not None -try: - _omegaconf_version = importlib_metadata.version("omegaconf") - logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}") -except importlib_metadata.PackageNotFoundError: - _omegaconf_available = False _tensorboard_available = importlib.util.find_spec("tensorboard") try: @@ -345,10 +339,6 @@ def is_wandb_available(): return _wandb_available -def is_omegaconf_available(): - return _omegaconf_available - - def is_tensorboard_available(): return _tensorboard_available @@ -449,12 +439,6 @@ WANDB_IMPORT_ERROR = """ install wandb` """ -# docstyle-ignore -OMEGACONF_IMPORT_ERROR = """ -{0} requires the omegaconf library but it was not found in your environment. You can install it with pip: `pip -install omegaconf` -""" - # docstyle-ignore TENSORBOARD_IMPORT_ERROR = """ {0} requires the tensorboard library but it was not found in your environment. 
You can install it with pip: `pip @@ -506,7 +490,6 @@ BACKENDS_MAPPING = OrderedDict( ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), - ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index df1a4fc420..86e31eb688 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -137,7 +137,7 @@ def get_tests_dir(append_path=None): tests_dir = os.path.dirname(tests_dir) if append_path: - return os.path.join(tests_dir, append_path) + return Path(tests_dir, append_path).as_posix() else: return tests_dir @@ -335,10 +335,9 @@ def require_python39_or_higher(test_case): def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray: if isinstance(arry, str): - # local_path = "/home/patrick_huggingface_co/" if local_path is not None: # local_path can be passed to correct images of tests - return os.path.join(local_path, "/".join([arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]])) + return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix() elif arry.startswith("http://") or arry.startswith("https://"): response = requests.get(arry) response.raise_for_status() @@ -520,10 +519,10 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non def load_hf_numpy(path) -> np.ndarray: - if not path.startswith("http://") or path.startswith("https://"): - path = os.path.join( - "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path) - ) + base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main" + + if not path.startswith("http://") and not path.startswith("https://"): + path = os.path.join(base_url, urllib.parse.quote(path)) return load_numpy(path) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 633ed9fc23..44cb730a95 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -38,8 +38,8 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "generator", "latents", "return_dict", - "callback", - "callback_steps", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", ] ) @@ -233,6 +233,43 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase): inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device) pipe(**inputs) + def test_free_init(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + free_init_generator = torch.Generator(device=torch_device).manual_seed(0) + pipe.enable_free_init( + num_iters=2, + use_fast_sampling=True, + method="butterworth", + order=4, + spatial_stop_frequency=0.25, + temporal_stop_frequency=0.25, + generator=free_init_generator, + ) + inputs_enable_free_init = self.get_dummy_inputs(torch_device) + frames_enable_free_init = 
pipe(**inputs_enable_free_init).frames[0] + + pipe.disable_free_init() + inputs_disable_free_init = self.get_dummy_inputs(torch_device) + frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max() + self.assertGreater( + sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results" + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeInit should lead to results similar to the default pipeline results", + ) + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed", diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index ce86933430..c034a9b68b 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -460,6 +460,33 @@ class StableDiffusionMultiControlNetPipelineFastTests( except NotImplementedError: pass + def test_inference_multiple_prompt_input(self): + device = "cpu" + + components = self.get_dummy_components() + sd_pipe = StableDiffusionControlNetPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"], inputs["prompt"]] + inputs["image"] = [inputs["image"], inputs["image"]] + output = sd_pipe(**inputs) + image = output.images + + assert image.shape == (2, 64, 64, 3) + + image_1, image_2 = image + # make sure that the outputs are different + assert np.sum(np.abs(image_1 - image_2)) > 1e-3 + + # multiple prompts, single image conditioning + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"], inputs["prompt"]] + output_1 = sd_pipe(**inputs) + + assert np.abs(image - output_1.images).max() < 1e-3 + class StableDiffusionMultiControlNetOneModelPipelineFastTests( PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py index 158385db02..e9758bf6eb 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py @@ -64,7 +64,7 @@ class StableDiffusionXLKPipelineIntegrationTests(unittest.TestCase): assert image.shape == (1, 512, 512, 3) expected_slice = np.array( - [0.79804534, 0.7981539, 0.8019961, 0.7936565, 0.7892033, 0.7914713, 0.7792827, 0.77754563, 0.7836789] + [0.79600024, 0.796546, 0.80682373, 0.79428387, 0.7905743, 0.8008807, 0.786183, 0.7835959, 0.797892] ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -97,7 +97,7 @@ class StableDiffusionXLKPipelineIntegrationTests(unittest.TestCase): assert image.shape == (1, 512, 512, 3) expected_slice = np.array( - [0.9704869, 0.9714559, 0.9693254, 0.96892524, 0.9685236, 0.9659081, 0.9666761, 0.9619067, 0.961759] + [0.9389532, 0.9408587, 0.9394901, 0.939082, 0.9402114, 0.9382007, 0.93737566, 0.9346897, 0.9324472] ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py 
b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py index 1197842436..871266fb9c 100644 --- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -185,6 +185,23 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa def test_inference_batch_consistent(self): pass + def test_np_output_type(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + inputs["output_type"] = "np" + output = pipe(**inputs).frames + self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(len(output.shape), 5) + def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4): components = self.get_dummy_components() pipe = self.pipeline_class(**components)