mirror of https://github.com/huggingface/diffusers.git

addressed PR comments

This commit is contained in:
ishan-modi
2025-04-07 17:33:21 +05:30
parent 3d085a2b95
commit dea5de50ae
3 changed files with 9 additions and 57 deletions


@@ -21,17 +21,6 @@ The abstract from the paper is:
 This model was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
 The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
-## Loading from the original format
-By default the [`SanaControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`]
-```py
-from diffusers import SanaControlNetModel
-import torch
-controlnet = SanaControlNetModel.from_pretrained(
-    "ishan24/Sana_600M_1024px_ControlNet_diffusers",
-)
-```
 ## SanaControlNetModel
 [[autodoc]] SanaControlNetModel
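The snippet removed above loaded the model without specifying a dtype. For reference, a minimal sketch of the loading pattern the rest of this commit converges on, assuming the same Hub checkpoint and taking float16 from the updated example docstring further down:

```py
import torch
from diffusers import SanaControlNetModel

# Load the ControlNet weights; float16 matches the dtype the updated
# example docstring in this commit settles on.
controlnet = SanaControlNetModel.from_pretrained(
    "ishan24/Sana_600M_1024px_ControlNet_diffusers",
    torch_dtype=torch.float16,
)
```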


@@ -27,36 +27,6 @@ The abstract from the paper is:
 This pipeline was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
 The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
-## Loading from the original format
-```py
-import torch
-from diffusers import SanaControlNetModel, SanaControlNetPipeline
-from diffusers.utils import load_image
-controlnet = SanaControlNetModel.from_pretrained(
-    "ishan24/Sana_600M_1024px_ControlNet_diffusers",
-    torch_dtype=torch.float16
-)
-pipe = SanaControlNetPipeline.from_pretrained(
-    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
-    variant="fp16",
-    controlnet=controlnet,
-    torch_dtype={'default': torch.bfloat16, 'transformer': torch.float16},
-)
-pipe.to('cuda')
-cond_image = load_image(
-    "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
-)
-prompt='a cat with a neon sign that says "Sana"'
-image = pipe(
-    prompt,
-    control_image=cond_image,
-).images[0]
-image.save("sana.png")
-```
 ## SanaControlNetPipeline
 [[autodoc]] SanaControlNetPipeline
 - all
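One detail of the removed example worth keeping in mind: `torch_dtype` is passed as a dict. Diffusers' `from_pretrained` accepts a per-component dtype mapping whose `"default"` key is the fallback, so everything here loads in bfloat16 except the transformer, which loads in float16 to match the ControlNet. A minimal sketch of that behavior, reusing the checkpoints above:

```py
import torch
from diffusers import SanaControlNetModel, SanaControlNetPipeline

controlnet = SanaControlNetModel.from_pretrained(
    "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.float16
)
# "default" is the fallback dtype; named components override it.
pipe = SanaControlNetPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    variant="fp16",
    controlnet=controlnet,
    torch_dtype={"default": torch.bfloat16, "transformer": torch.float16},
)
print(pipe.transformer.dtype)   # torch.float16
print(pipe.text_encoder.dtype)  # torch.bfloat16
print(pipe.vae.dtype)           # torch.bfloat16
```

Keeping the text encoder and VAE in bfloat16 while the transformer runs in float16 is presumably why the earlier example docstring cast `pipe.vae` and `pipe.text_encoder` to bfloat16 by hand; the dict form expresses the same intent in one place.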


@@ -115,17 +115,15 @@ EXAMPLE_DOC_STRING = """
 >>> from diffusers.utils import load_image

 >>> controlnet = SanaControlNetModel.from_pretrained(
-...     "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
+...     "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.float16
 ... )
 >>> pipe = SanaControlNetPipeline.from_pretrained(
 ...     "Efficient-Large-Model/Sana_600M_1024px_diffusers",
 ...     variant="fp16",
-...     torch_dtype=torch.float16,
 ...     controlnet=controlnet,
+...     torch_dtype={"default": torch.bfloat16, "transformer": torch.float16},
 ... )
 >>> pipe.to("cuda")
->>> pipe.vae.to(torch.bfloat16)
->>> pipe.text_encoder.to(torch.bfloat16)

 >>> cond_image = load_image(
 ...     "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
 ... )
@@ -952,7 +950,6 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             height, width = control_image.shape[-2:]
             control_image = self.vae.encode(control_image).latent
-            control_image = control_image.to(self.vae.dtype)
             control_image = control_image * self.vae.config.scaling_factor
         else:
@@ -983,6 +980,8 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
+        controlnet_dtype = self.controlnet.dtype
+        transformer_dtype = self.transformer.dtype
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -994,12 +993,9 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 timestep = t.expand(latent_model_input.shape[0])

                 # controlnet(s) inference
-                latent_model_input = latent_model_input.to(dtype=self.controlnet.dtype)
-                prompt_embeds = prompt_embeds.to(dtype=self.controlnet.dtype)
-                control_image = control_image.to(dtype=self.controlnet.dtype)
                 controlnet_block_samples = self.controlnet(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
+                    latent_model_input.to(dtype=controlnet_dtype),
+                    encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype),
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
                     return_dict=False,
@@ -1009,17 +1005,14 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 )[0]

                 # predict noise model_output
-                latent_model_input = latent_model_input.to(dtype=self.transformer.dtype)
-                prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
-                controlnet_block_samples = controlnet_block_samples.to(dtype=self.transformer.dtype)
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
+                    latent_model_input.to(dtype=transformer_dtype),
+                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
                     return_dict=False,
                     attention_kwargs=self.attention_kwargs,
-                    controlnet_block_samples=controlnet_block_samples,
+                    controlnet_block_samples=controlnet_block_samples.to(dtype=transformer_dtype),
                 )[0]
                 noise_pred = noise_pred.float()
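Taken together, the pipeline hunks make one small refactor: instead of rebinding `latent_model_input`, `prompt_embeds`, `control_image`, and `controlnet_block_samples` to casted copies on every denoising step (re-reading `self.controlnet.dtype` and `self.transformer.dtype` each time), the dtypes are now read once before the loop and the casts happen inline at each call site, leaving the original tensors untouched; the cast after `self.vae.encode` goes away for the same reason. A toy sketch of the pattern with stand-in modules rather than the real pipeline API:

```py
import torch
from torch import nn

# Stand-ins for the pipeline's controlnet and transformer, chosen only to
# show the casting pattern (bfloat16/float32 keep the sketch CPU-friendly).
controlnet = nn.Linear(8, 8).to(torch.bfloat16)
transformer = nn.Linear(8, 8)  # float32

latents = torch.randn(1, 8)
prompt_embeds = torch.randn(1, 8)

# Read each module's dtype once, outside the loop. Diffusers models expose
# a `.dtype` property; plain nn.Modules need a peek at their parameters.
controlnet_dtype = next(controlnet.parameters()).dtype
transformer_dtype = next(transformer.parameters()).dtype

for _ in range(3):  # stand-in for the timestep loop
    # Cast per call instead of rebinding, so `latents` and `prompt_embeds`
    # keep their original dtype across iterations.
    control_out = controlnet((latents + prompt_embeds).to(controlnet_dtype))
    noise_pred = transformer(
        (latents + prompt_embeds).to(transformer_dtype)
        + control_out.to(transformer_dtype)
    )
    noise_pred = noise_pred.float()
```

Behavior is unchanged; the rewrite just drops a few tensor rebinds per step and keeps each cast next to the call that needs it.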