mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00

Remove negative_* from SDXL callback (#10203)

* Remove `negative_*` from SDXL callback
* Change example and add XL version
@@ -241,27 +241,15 @@ from diffusers import StableDiffusionPipeline
 from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
 from diffusers.configuration_utils import register_to_config
 import torch
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Tuple, Union


-pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
-    "stable-diffusion-v1-5/stable-diffusion-v1-5",
-    torch_dtype=torch.float16,
-    variant="fp16",
-    use_safetensors=True,
-).to("cuda")
-pipeline.safety_checker = None
-pipeline.requires_safety_checker = False
-
-
-class SDPromptScheduleCallback(PipelineCallback):
+class SDPromptSchedulingCallback(PipelineCallback):
     @register_to_config
     def __init__(
         self,
-        prompt: str,
-        negative_prompt: Optional[str] = None,
-        num_images_per_prompt: int = 1,
-        cutoff_step_ratio=1.0,
+        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        cutoff_step_ratio=None,
         cutoff_step_index=None,
     ):
         super().__init__(
@@ -275,6 +263,10 @@ class SDPromptScheduleCallback(PipelineCallback):
     ) -> Dict[str, Any]:
         cutoff_step_ratio = self.config.cutoff_step_ratio
         cutoff_step_index = self.config.cutoff_step_index
+        if isinstance(self.config.encoded_prompt, tuple):
+            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+        else:
+            prompt_embeds = self.config.encoded_prompt

         # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
         cutoff_step = (
@@ -284,34 +276,164 @@ class SDPromptScheduleCallback(PipelineCallback):
         )

         if step_index == cutoff_step:
-            prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
-                prompt=self.config.prompt,
-                negative_prompt=self.config.negative_prompt,
-                device=pipeline._execution_device,
-                num_images_per_prompt=self.config.num_images_per_prompt,
-                do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
-            )
             if pipeline.do_classifier_free_guidance:
                 prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
             callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
         return callback_kwargs


+pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True,
+).to("cuda")
+pipeline.safety_checker = None
+pipeline.requires_safety_checker = False
+
 callback = MultiPipelineCallbacks(
     [
-        SDPromptScheduleCallback(
-            prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
-            negative_prompt="Deformed, ugly, bad anatomy",
-            cutoff_step_ratio=0.25,
-        )
+        SDPromptSchedulingCallback(
+            encoded_prompt=pipeline.encode_prompt(
+                prompt=f"prompt {index}",
+                negative_prompt=f"negative prompt {index}",
+                device=pipeline._execution_device,
+                num_images_per_prompt=1,
+                # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
+                do_classifier_free_guidance=True,
+            ),
+            cutoff_step_index=index,
+        ) for index in range(1, 20)
     ]
 )

 image = pipeline(
-    prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
-    negative_prompt="Deformed, ugly, bad anatomy",
+    prompt="prompt",
+    negative_prompt="negative prompt",
     callback_on_step_end=callback,
     callback_on_step_end_tensor_inputs=["prompt_embeds"],
 ).images[0]
+torch.cuda.empty_cache()
 image.save('image.png')
 ```
+
+```python
+from diffusers import StableDiffusionXLPipeline
+from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
+from diffusers.configuration_utils import register_to_config
+import torch
+from typing import Any, Dict, Tuple, Union
+
+
+class SDXLPromptSchedulingCallback(PipelineCallback):
+    @register_to_config
+    def __init__(
+        self,
+        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        cutoff_step_ratio=None,
+        cutoff_step_index=None,
+    ):
+        super().__init__(
+            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
+        )
+
+    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
+
+    def callback_fn(
+        self, pipeline, step_index, timestep, callback_kwargs
+    ) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+        if isinstance(self.config.encoded_prompt, tuple):
+            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+        else:
+            prompt_embeds = self.config.encoded_prompt
+        if isinstance(self.config.add_text_embeds, tuple):
+            add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
+        else:
+            add_text_embeds = self.config.add_text_embeds
+        if isinstance(self.config.add_time_ids, tuple):
+            add_time_ids, negative_add_time_ids = self.config.add_time_ids
+        else:
+            add_time_ids = self.config.add_time_ids
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index
+            if cutoff_step_index is not None
+            else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            if pipeline.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+                add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
+                add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+            callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+        return callback_kwargs
+
+
+pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True,
+).to("cuda")
+
+callbacks = []
+for index in range(1, 20):
+    (
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    ) = pipeline.encode_prompt(
+        prompt=f"prompt {index}",
+        negative_prompt=f"prompt {index}",
+        device=pipeline._execution_device,
+        num_images_per_prompt=1,
+        # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
+        do_classifier_free_guidance=True,
+    )
+    text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+    add_time_ids = pipeline._get_add_time_ids(
+        (1024, 1024),
+        (0, 0),
+        (1024, 1024),
+        dtype=prompt_embeds.dtype,
+        text_encoder_projection_dim=text_encoder_projection_dim,
+    )
+    negative_add_time_ids = pipeline._get_add_time_ids(
+        (1024, 1024),
+        (0, 0),
+        (1024, 1024),
+        dtype=prompt_embeds.dtype,
+        text_encoder_projection_dim=text_encoder_projection_dim,
+    )
+    callbacks.append(
+        SDXLPromptSchedulingCallback(
+            encoded_prompt=(prompt_embeds, negative_prompt_embeds),
+            add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
+            add_time_ids=(add_time_ids, negative_add_time_ids),
+            cutoff_step_index=index,
+        )
+    )
+
+
+callback = MultiPipelineCallbacks(callbacks)
+
+image = pipeline(
+    prompt="prompt",
+    negative_prompt="negative prompt",
+    callback_on_step_end=callback,
+    callback_on_step_end_tensor_inputs=[
+        "prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+    ],
+).images[0]
+```
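As an aside on the updated documentation example above: `callback_on_step_end` also accepts a plain callable with the signature `(pipeline, step_index, timestep, callback_kwargs)`, so roughly the same prompt scheduling can be sketched without subclassing `PipelineCallback`. This is a minimal sketch, assuming a Stable Diffusion 1.5 pipeline is already loaded as `pipeline` as in the example; the `schedule` dict and `scheduled_prompt_callback` names are illustrative and not part of this commit.

```python
import torch

# Pre-encode one prompt per step before inference (illustrative schedule).
schedule = {}
for index in range(1, 20):
    prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
        prompt=f"prompt {index}",
        negative_prompt=f"negative prompt {index}",
        device=pipeline._execution_device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )
    schedule[index] = (prompt_embeds, negative_prompt_embeds)


def scheduled_prompt_callback(pipe, step_index, timestep, callback_kwargs):
    # Swap in the pre-encoded embeddings whenever the current step is scheduled.
    if step_index in schedule:
        prompt_embeds, negative_prompt_embeds = schedule[step_index]
        if pipe.do_classifier_free_guidance:
            # CFG runs on a single concatenated batch, so the negative embeddings
            # are folded into prompt_embeds, mirroring the example above.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs


image = pipeline(
    prompt="prompt",
    negative_prompt="negative prompt",
    callback_on_step_end=scheduled_prompt_callback,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
```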
@@ -237,11 +237,8 @@ class StableDiffusionXLPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "negative_add_time_ids",
     ]

     def __init__(
@@ -1243,13 +1240,8 @@ class StableDiffusionXLPipeline(

                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                     add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
                     add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
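The popped lines above sit inside the denoising loop, where the pipeline hands the callback only the tensors requested via `callback_on_step_end_tensor_inputs` and then reads any replacements back out. A minimal sketch of observing that from the user side, assuming an SDXL pipeline is already loaded as `pipeline`; `inspect_callback` is an illustrative name, not part of this commit.

```python
# Minimal sketch (assumed usage): the callback only receives the tensors that
# were explicitly requested via callback_on_step_end_tensor_inputs.
def inspect_callback(pipe, step_index, timestep, callback_kwargs):
    # With this change, negative_* names can no longer be requested from the
    # SDXL pipelines, so they will not show up in callback_kwargs.
    print(step_index, sorted(callback_kwargs.keys()))
    return callback_kwargs


image = pipeline(
    prompt="prompt",
    callback_on_step_end=inspect_callback,
    callback_on_step_end_tensor_inputs=["latents", "prompt_embeds"],
).images[0]
```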
@@ -257,11 +257,8 @@ class StableDiffusionXLImg2ImgPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "add_neg_time_ids",
     ]

     def __init__(
@@ -1438,13 +1435,8 @@ class StableDiffusionXLImg2ImgPipeline(

                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                     add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
                     add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -285,11 +285,8 @@ class StableDiffusionXLInpaintPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "add_neg_time_ids",
         "mask",
         "masked_image_latents",
     ]
@@ -1671,13 +1668,8 @@ class StableDiffusionXLInpaintPipeline(

                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                     add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
                     add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
                     mask = callback_outputs.pop("mask", mask)
                     masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
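With the `negative_*` entries dropped from `_callback_tensor_inputs`, requesting them through `callback_on_step_end_tensor_inputs` should now fail these pipelines' input validation. A minimal sketch of that pre-flight check from the caller's side, assuming an SDXL pipeline loaded as `pipeline` and that the check mirrors the usual `check_inputs` behaviour; the `requested` and `not_supported` names are illustrative.

```python
# Illustrative pre-flight check mirroring the pipelines' validation (assumed).
requested = ["prompt_embeds", "add_text_embeds", "add_time_ids"]  # still supported
# requested = ["negative_prompt_embeds"]  # no longer in _callback_tensor_inputs

not_supported = [k for k in requested if k not in pipeline._callback_tensor_inputs]
if not_supported:
    raise ValueError(
        f"`callback_on_step_end_tensor_inputs` has to be in {pipeline._callback_tensor_inputs}, "
        f"but found {not_supported}"
    )
```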