mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
addressed PR comments
This commit is contained in:
@@ -21,17 +21,6 @@ The abstract from the paper is:
|
||||
This model was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
|
||||
The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
|
||||
|
||||
## Loading from the original format
|
||||
By default, the [`SanaControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`]:
|
||||
```py
|
||||
from diffusers import SanaControlNetModel
|
||||
import torch
|
||||
|
||||
controlnet = SanaControlNetModel.from_pretrained(
|
||||
"ishan24/Sana_600M_1024px_ControlNet_diffusers",
|
||||
)
|
||||
```
|
||||
|
||||
## SanaControlNetModel
|
||||
[[autodoc]] SanaControlNetModel
|
||||
|
||||
|
||||
@@ -27,36 +27,6 @@ The abstract from the paper is:
|
||||
This pipeline was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
|
||||
The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
|
||||
|
||||
## Loading from the original format
|
||||
```py
|
||||
import torch
|
||||
from diffusers import SanaControlNetModel, SanaControlNetPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
controlnet = SanaControlNetModel.from_pretrained(
|
||||
"ishan24/Sana_600M_1024px_ControlNet_diffusers",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipe = SanaControlNetPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_600M_1024px_diffusers",
|
||||
variant="fp16",
|
||||
controlnet=controlnet,
|
||||
torch_dtype={'default': torch.bfloat16, 'transformer': torch.float16},
|
||||
)
|
||||
pipe.to('cuda')
|
||||
|
||||
cond_image = load_image(
|
||||
"https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
|
||||
)
|
||||
prompt='a cat with a neon sign that says "Sana"'
|
||||
image = pipe(
|
||||
prompt,
|
||||
control_image=cond_image,
|
||||
).images[0]
|
||||
image.save("sana.png")
|
||||
```
|
||||
|
||||
## SanaControlNetPipeline
|
||||
[[autodoc]] SanaControlNetPipeline
|
||||
- all
|
||||
|
||||
@@ -115,17 +115,15 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> controlnet = SanaControlNetModel.from_pretrained(
|
||||
... "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
|
||||
... "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = SanaControlNetPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_600M_1024px_diffusers",
|
||||
... variant="fp16",
|
||||
... torch_dtype=torch.float16,
|
||||
... controlnet=controlnet,
|
||||
... torch_dtype={"default": torch.bfloat16, "transformer": torch.float16},
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> pipe.vae.to(torch.bfloat16)
|
||||
>>> pipe.text_encoder.to(torch.bfloat16)
|
||||
>>> cond_image = load_image(
|
||||
... "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
|
||||
... )
|
||||
@@ -952,7 +950,6 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
control_image = self.vae.encode(control_image).latent
|
||||
control_image = control_image.to(self.vae.dtype)
|
||||
control_image = control_image * self.vae.config.scaling_factor
|
||||
|
||||
else:
|
||||
@@ -983,6 +980,8 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
controlnet_dtype = self.controlnet.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
@@ -994,12 +993,9 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# controlnet(s) inference
|
||||
latent_model_input = latent_model_input.to(dtype=self.controlnet.dtype)
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.controlnet.dtype)
|
||||
control_image = control_image.to(dtype=self.controlnet.dtype)
|
||||
controlnet_block_samples = self.controlnet(
|
||||
latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
latent_model_input.to(dtype=controlnet_dtype),
|
||||
encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype),
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
timestep=timestep,
|
||||
return_dict=False,
|
||||
@@ -1009,17 +1005,14 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
)[0]
|
||||
|
||||
# predict noise model_output
|
||||
latent_model_input = latent_model_input.to(dtype=self.transformer.dtype)
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
|
||||
controlnet_block_samples = controlnet_block_samples.to(dtype=self.transformer.dtype)
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
latent_model_input.to(dtype=transformer_dtype),
|
||||
encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
timestep=timestep,
|
||||
return_dict=False,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_block_samples=controlnet_block_samples.to(dtype=transformer_dtype),
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user