Mirror of https://github.com/huggingface/diffusers.git
Adds denoising_end parameter to ControlNetPipeline for SDXL (#6175)
* Initial commit

* Removed copy hints, as in the original SDXLControlNetPipeline, since `make fix-copies` seems to have issues with the `@property` decorator.

* Reverted changes to ControlNetXS

* Addendum to: Reverted changes to ControlNetXS

* Added test + docs for mixture of denoisers

* Update docs/source/en/using-diffusers/controlnet.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/controlnet.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@@ -429,6 +429,27 @@ image = pipe(
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
```

<Tip>

You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`.
See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model.
Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`.

```py
base = StableDiffusionXLControlNetPipeline(...)
image = base(
    prompt=prompt,
    controlnet_conditioning_scale=0.5,
    image=canny_image,
    num_inference_steps=40,
    denoising_end=0.8,
    output_type="latent",
).images
# rest exactly as with StableDiffusionXLPipeline
```

</Tip>
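The `# rest exactly as with StableDiffusionXLPipeline` comment above elides the refiner stage. As a minimal sketch of what that stage typically looks like, following the linked SDXL guide (the checkpoint name and shared components here are illustrative assumptions, not part of this commit):

```py
# Sketch of the elided refiner stage; continues the snippet above and assumes
# `base`, `prompt`, and the latent `image` it produced.
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline

refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
).to("cuda")

# denoising_start should match the base pipeline's denoising_end (0.8 here) so
# the refiner resumes from the exact noise level where the base pipeline stopped.
image = refiner(
    prompt=prompt,
    num_inference_steps=40,
    denoising_start=0.8,
    image=image,
).images[0]
```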

## MultiControlNet

<Tip>
@@ -916,6 +916,10 @@ class StableDiffusionXLControlNetPipeline(
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def denoising_end(self):
        return self._denoising_end

    @property
    def num_timesteps(self):
        return self._num_timesteps
@@ -930,6 +934,7 @@ class StableDiffusionXLControlNetPipeline(
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -989,6 +994,13 @@ class StableDiffusionXLControlNetPipeline(
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. As a result, the returned sample will
                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
                scheduler. The `denoising_end` parameter should ideally be utilized when this pipeline forms a part
                of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
            guidance_scale (`float`, *optional*, defaults to 5.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -1151,6 +1163,7 @@ class StableDiffusionXLControlNetPipeline(
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
@@ -1325,6 +1338,23 @@ class StableDiffusionXLControlNetPipeline(

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 8.1 Apply denoising_end
        if (
            self.denoising_end is not None
            and isinstance(self.denoising_end, float)
            and self.denoising_end > 0
            and self.denoising_end < 1
        ):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        is_unet_compiled = is_compiled_module(self.unet)
        is_controlnet_compiled = is_compiled_module(self.controlnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
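To make the cutoff arithmetic above concrete, here is a standalone sketch with illustrative values (1000 matches the usual `num_train_timesteps` in SDXL scheduler configs, and the 10-step schedule mirrors the Euler timesteps used in the tests below):

```py
# Illustration of the denoising_end cutoff; values are examples, not pipeline code.
num_train_timesteps = 1000
denoising_end = 0.8
timesteps = [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]  # 10 inference steps

cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
kept = [ts for ts in timesteps if ts >= cutoff]
print(cutoff, len(kept))  # 200 8 -> only the first 80% of the steps are run
```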
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import gc
import unittest
@@ -24,8 +25,10 @@ from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LCMScheduler,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLImg2ImgPipeline,
    UNet2DConditionModel,
)
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
@@ -364,6 +367,110 @@ class StableDiffusionXLControlNetPipelineFastTests(

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    # copied from test_stable_diffusion_xl.py:test_stable_diffusion_two_xl_mixture_of_denoiser_fast
    # with `StableDiffusionXLControlNetPipeline` instead of `StableDiffusionXLPipeline`
    def test_controlnet_sdxl_two_mixture_of_denoiser_fast(self):
        components = self.get_dummy_components()
        pipe_1 = StableDiffusionXLControlNetPipeline(**components).to(torch_device)
        pipe_1.unet.set_default_attn_processor()

        components_without_controlnet = {k: v for k, v in components.items() if k != "controlnet"}
        pipe_2 = StableDiffusionXLImg2ImgPipeline(**components_without_controlnet).to(torch_device)
        pipe_2.unet.set_default_attn_processor()

        def assert_run_mixture(
            num_steps,
            split,
            scheduler_cls_orig,
            expected_tss,
            num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps,
        ):
            inputs = self.get_dummy_inputs(torch_device)
            inputs["num_inference_steps"] = num_steps

            class scheduler_cls(scheduler_cls_orig):
                pass

            pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
            pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)

            # Let's retrieve the number of timesteps we want to use
            pipe_1.scheduler.set_timesteps(num_steps)
            expected_steps = pipe_1.scheduler.timesteps.tolist()
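            # Second-order schedulers such as Heun evaluate every timestep twice
            # (see the duplicated entries in the Heun list below), and the timestep
            # at the split boundary is shared: the last step of pipe_1 reappears as
            # the first expected step of pipe_2, hence the prepended element.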
            if pipe_1.scheduler.order == 2:
                expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
                expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
                expected_steps = expected_steps_1 + expected_steps_2
            else:
                expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
                expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))

            # now we monkey patch a `done_steps` list into the step function for testing
            done_steps = []
            old_step = copy.copy(scheduler_cls.step)

            def new_step(self, *args, **kwargs):
                done_steps.append(args[1].cpu().item())  # args[1] is always the passed `t`
                return old_step(self, *args, **kwargs)

            scheduler_cls.step = new_step

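            # With the scheduler's num_train_timesteps of 1000, a split of 300 maps
            # to denoising_end = denoising_start = 1.0 - 300/1000 = 0.7: pipe_1
            # denoises timesteps >= 300, returns latents, and pipe_2 resumes from
            # that point via denoising_start.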
            inputs_1 = {
                **inputs,
                **{
                    "denoising_end": 1.0 - (split / num_train_timesteps),
                    "output_type": "latent",
                },
            }
            latents = pipe_1(**inputs_1).images[0]

            assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"

            inputs_2 = {
                **inputs,
                **{
                    "denoising_start": 1.0 - (split / num_train_timesteps),
                    "image": latents,
                },
            }
            pipe_2(**inputs_2).images[0]

            assert expected_steps_2 == done_steps[len(expected_steps_1) :]
            assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"

        steps = 10
        for split in [300, 700]:
            for scheduler_cls_timesteps in [
                (EulerDiscreteScheduler, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]),
                (
                    HeunDiscreteScheduler,
                    [
                        901.0,
                        801.0,
                        801.0,
                        701.0,
                        701.0,
                        601.0,
                        601.0,
                        501.0,
                        501.0,
                        401.0,
                        401.0,
                        301.0,
                        301.0,
                        201.0,
                        201.0,
                        101.0,
                        101.0,
                        1.0,
                        1.0,
                    ],
                ),
            ]:
                assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1])


class StableDiffusionXLMultiControlNetPipelineFastTests(
    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase