From dea5de50ae4bc109be68df432cac16b42cbb23e0 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Mon, 7 Apr 2025 17:33:21 +0530 Subject: [PATCH] addressed PR comments --- docs/source/en/api/models/controlnet_sana.md | 11 ------- .../en/api/pipelines/controlnet_sana.md | 30 ------------------- .../sana/pipeline_sana_controlnet.py | 25 ++++++---------- 3 files changed, 9 insertions(+), 57 deletions(-) diff --git a/docs/source/en/api/models/controlnet_sana.md b/docs/source/en/api/models/controlnet_sana.md index 2504b8c58e..f0426308f7 100644 --- a/docs/source/en/api/models/controlnet_sana.md +++ b/docs/source/en/api/models/controlnet_sana.md @@ -21,17 +21,6 @@ The abstract from the paper is: This model was contributed by [ishan24](https://huggingface.co/ishan24). ❤️ The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile. -## Loading from the original format -By default the [`SanaControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`] -```py -from diffusers import SanaControlNetModel -import torch - -controlnet = SanaControlNetModel.from_pretrained( - "ishan24/Sana_600M_1024px_ControlNet_diffusers", -) -``` - ## SanaControlNetModel [[autodoc]] SanaControlNetModel diff --git a/docs/source/en/api/pipelines/controlnet_sana.md b/docs/source/en/api/pipelines/controlnet_sana.md index 952e9b499d..67ec882d68 100644 --- a/docs/source/en/api/pipelines/controlnet_sana.md +++ b/docs/source/en/api/pipelines/controlnet_sana.md @@ -27,36 +27,6 @@ The abstract from the paper is: This pipeline was contributed by [ishan24](https://huggingface.co/ishan24). ❤️ The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile. -## Loading from the original format -```py -import torch -from diffusers import SanaControlNetModel, SanaControlNetPipeline -from diffusers.utils import load_image - -controlnet = SanaControlNetModel.from_pretrained( - "ishan24/Sana_600M_1024px_ControlNet_diffusers", - torch_dtype=torch.float16 -) - -pipe = SanaControlNetPipeline.from_pretrained( - "Efficient-Large-Model/Sana_600M_1024px_diffusers", - variant="fp16", - controlnet=controlnet, - torch_dtype={'default': torch.bfloat16, 'transformer': torch.float16}, -) -pipe.to('cuda') - -cond_image = load_image( - "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png" -) -prompt='a cat with a neon sign that says "Sana"' -image = pipe( - prompt, - control_image=cond_image, -).images[0] -image.save("sana.png") -``` - ## SanaControlNetPipeline [[autodoc]] SanaControlNetPipeline - all diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py index 2071fbc0c5..d5d552ea27 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -115,17 +115,15 @@ EXAMPLE_DOC_STRING = """ >>> from diffusers.utils import load_image >>> controlnet = SanaControlNetModel.from_pretrained( - ... "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16 + ... "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.float16 ... ) >>> pipe = SanaControlNetPipeline.from_pretrained( ... "Efficient-Large-Model/Sana_600M_1024px_diffusers", ... variant="fp16", - ... torch_dtype=torch.float16, ... controlnet=controlnet, + ... torch_dtype={"default": torch.bfloat16, "transformer": torch.float16}, ... ) >>> pipe.to("cuda") - >>> pipe.vae.to(torch.bfloat16) - >>> pipe.text_encoder.to(torch.bfloat16) >>> cond_image = load_image( ... "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png" ... ) @@ -952,7 +950,6 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): height, width = control_image.shape[-2:] control_image = self.vae.encode(control_image).latent - control_image = control_image.to(self.vae.dtype) control_image = control_image * self.vae.config.scaling_factor else: @@ -983,6 +980,8 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + controlnet_dtype = self.controlnet.dtype + transformer_dtype = self.transformer.dtype with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -994,12 +993,9 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): timestep = t.expand(latent_model_input.shape[0]) # controlnet(s) inference - latent_model_input = latent_model_input.to(dtype=self.controlnet.dtype) - prompt_embeds = prompt_embeds.to(dtype=self.controlnet.dtype) - control_image = control_image.to(dtype=self.controlnet.dtype) controlnet_block_samples = self.controlnet( - latent_model_input, - encoder_hidden_states=prompt_embeds, + latent_model_input.to(dtype=controlnet_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype), encoder_attention_mask=prompt_attention_mask, timestep=timestep, return_dict=False, @@ -1009,17 +1005,14 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): )[0] # predict noise model_output - latent_model_input = latent_model_input.to(dtype=self.transformer.dtype) - prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) - controlnet_block_samples = controlnet_block_samples.to(dtype=self.transformer.dtype) noise_pred = self.transformer( - latent_model_input, - encoder_hidden_states=prompt_embeds, + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), encoder_attention_mask=prompt_attention_mask, timestep=timestep, return_dict=False, attention_kwargs=self.attention_kwargs, - controlnet_block_samples=controlnet_block_samples, + controlnet_block_samples=controlnet_block_samples.to(dtype=transformer_dtype), )[0] noise_pred = noise_pred.float()