mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
bug fixes
This commit is contained in:
@@ -16,6 +16,7 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
# Select the context manager bound to CTX: `init_empty_weights` when accelerate
# is available, otherwise the no-op `nullcontext`.
# `is_accelerate_available` is a function and has to be called; the bare
# function object is always truthy, so without the call the `nullcontext`
# fallback could never be chosen.
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
|
||||
def main(args):
|
||||
file_path = args.orig_ckpt_path
|
||||
|
||||
@@ -182,7 +183,7 @@ if __name__ == "__main__":
|
||||
"--model_type",
|
||||
default="SanaMS_1600M_P1_ControlNet_D7",
|
||||
type=str,
|
||||
choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"],
|
||||
choices=["SanaMS_1600M_P1_ControlNet_D7", "SanaMS_600M_P1_ControlNet_D7"],
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
@@ -231,7 +231,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
post_patch_height, post_patch_width = height // p, width // p
|
||||
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond))
|
||||
hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))
|
||||
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
|
||||
@@ -16,7 +16,6 @@ import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -111,17 +110,30 @@ EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import SanaPipeline
|
||||
>>> from diffusers import SanaControlNetModel, SanaControlNetPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = SanaPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
|
||||
>>> controlnet = SanaControlNetModel.from_pretrained(
|
||||
... "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = SanaControlNetPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_600M_1024px_diffusers",
|
||||
... variant="fp16",
|
||||
... torch_dtype=torch.float16,
|
||||
... controlnet=controlnet,
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> pipe.vae.to(torch.bfloat16)
|
||||
>>> pipe.text_encoder.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
|
||||
|
||||
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
|
||||
>>> image[0].save("output.png")
|
||||
>>> cond_image = load_image(
|
||||
... "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
|
||||
... )
|
||||
>>> prompt = 'a cat with a neon sign that says "Sana"'
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... control_image=cond_image,
|
||||
... ).images[0]
|
||||
>>> image.save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -936,7 +948,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.dtype,
|
||||
dtype=self.vae.dtype,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=False,
|
||||
)
|
||||
@@ -1016,8 +1028,6 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
# learned sigma
|
||||
if self.transformer.config.out_channels // 2 == latent_channels:
|
||||
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
||||
else:
|
||||
noise_pred = noise_pred
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
Reference in New Issue
Block a user