1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

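Bug fixes: list the ControlNet variants as the valid `--model_type` choices in the conversion script, cast `controlnet_cond` to the hidden-state dtype in `SanaControlNetModel`, prepare the control image in the VAE dtype, update the `SanaControlNetPipeline` docstring example, and remove a no-op `else` branch.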

This commit is contained in:
ishan-modi
2025-03-13 18:07:04 +05:30
parent dfc396e384
commit 009937be4b
3 changed files with 24 additions and 13 deletions

View File

@@ -16,6 +16,7 @@ from diffusers.utils.import_utils import is_accelerate_available
# Context manager used when instantiating the model: build on empty (meta)
# weights when accelerate is installed, otherwise fall back to a no-op context.
# `is_accelerate_available` is a function and must be *called*; referencing the
# bare function object is always truthy, which made the fallback unreachable.
CTX = init_empty_weights if is_accelerate_available() else nullcontext
def main(args):
file_path = args.orig_ckpt_path
@@ -182,7 +183,7 @@ if __name__ == "__main__":
"--model_type",
default="SanaMS_1600M_P1_ControlNet_D7",
type=str,
choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"],
choices=["SanaMS_1600M_P1_ControlNet_D7", "SanaMS_600M_P1_ControlNet_D7"],
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

View File

@@ -231,7 +231,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
post_patch_height, post_patch_width = height // p, width // p
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond))
hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype

View File

@@ -16,7 +16,6 @@ import html
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -111,17 +110,30 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import SanaPipeline
>>> from diffusers import SanaControlNetModel, SanaControlNetPipeline
>>> from diffusers.utils import load_image
>>> pipe = SanaPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
>>> controlnet = SanaControlNetModel.from_pretrained(
... "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe = SanaControlNetPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_600M_1024px_diffusers",
... variant="fp16",
... torch_dtype=torch.float16,
... controlnet=controlnet,
... )
>>> pipe.to("cuda")
>>> pipe.vae.to(torch.bfloat16)
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
>>> image[0].save("output.png")
>>> cond_image = load_image(
... "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
... )
>>> prompt = 'a cat with a neon sign that says "Sana"'
>>> image = pipe(
... prompt,
... control_image=cond_image,
... ).images[0]
>>> image.save("output.png")
```
"""
@@ -936,7 +948,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.dtype,
dtype=self.vae.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=False,
)
@@ -1016,8 +1028,6 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]