diff --git a/scripts/convert_sana_controlnet_to_diffusers.py b/scripts/convert_sana_controlnet_to_diffusers.py index 743d6ea315..bc1eb32788 100644 --- a/scripts/convert_sana_controlnet_to_diffusers.py +++ b/scripts/convert_sana_controlnet_to_diffusers.py @@ -16,6 +16,7 @@ from diffusers.utils.import_utils import is_accelerate_available CTX = init_empty_weights if is_accelerate_available else nullcontext + def main(args): file_path = args.orig_ckpt_path @@ -182,7 +183,7 @@ if __name__ == "__main__": "--model_type", default="SanaMS_1600M_P1_ControlNet_D7", type=str, - choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"], + choices=["SanaMS_1600M_P1_ControlNet_D7", "SanaMS_600M_P1_ControlNet_D7"], ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.") diff --git a/src/diffusers/models/controlnets/controlnet_sana.py b/src/diffusers/models/controlnets/controlnet_sana.py index f3055ad7e5..d9b0d73c24 100644 --- a/src/diffusers/models/controlnets/controlnet_sana.py +++ b/src/diffusers/models/controlnets/controlnet_sana.py @@ -231,7 +231,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): post_patch_height, post_patch_width = height // p, width // p hidden_states = self.patch_embed(hidden_states) - hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond)) + hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype))) timestep, embedded_timestep = self.time_embed( timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py index 5faa6f0a4a..71479c4e60 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -16,7 +16,6 @@ import html import inspect import re import urllib.parse as ul -import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -111,17 +110,30 @@ EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch - >>> from diffusers import SanaPipeline + >>> from diffusers import SanaControlNetModel, SanaControlNetPipeline + >>> from diffusers.utils import load_image - >>> pipe = SanaPipeline.from_pretrained( - ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32 + >>> controlnet = SanaControlNetModel.from_pretrained( + ... "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = SanaControlNetPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_600M_1024px_diffusers", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... controlnet=controlnet, ... ) >>> pipe.to("cuda") + >>> pipe.vae.to(torch.bfloat16) >>> pipe.text_encoder.to(torch.bfloat16) - >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) - - >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] - >>> image[0].save("output.png") + >>> cond_image = load_image( + ... "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png" + ... ) + >>> prompt = 'a cat with a neon sign that says "Sana"' + >>> image = pipe( + ... prompt, + ... control_image=cond_image, + ... ).images[0] + >>> image.save("output.png") ``` """ @@ -936,7 +948,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=self.dtype, + dtype=self.vae.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=False, ) @@ -1016,8 +1028,6 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): # learned sigma if self.transformer.config.out_channels // 2 == latent_channels: noise_pred = noise_pred.chunk(2, dim=1)[0] - else: - noise_pred = noise_pred # compute previous image: x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]