1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add skip_layers argument to SD3 transformer model class (#9880)

* add skip_layers argument to SD3 transformer model class

* add unit test for skip_layers in stable diffusion 3

* sd3: pipeline should support skip layer guidance

* up

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
This commit is contained in:
Bagheera
2024-11-19 14:22:54 -06:00
committed by GitHub
parent cc7d88f247
commit 99c0483b67
3 changed files with 82 additions and 6 deletions

View File

@@ -268,6 +268,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
skip_layers: Optional[List[int]] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
timestep (`torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
block_controlnet_hidden_states (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
skip_layers (`list` of `int`, *optional*):
A list of layer indices to skip during the forward pass.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,7 +320,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
# Skip specified layers
is_skip = True if skip_layers is not None and index_block in skip_layers else False
if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -336,8 +342,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
temb,
**ckpt_kwargs,
)
else:
elif not is_skip:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)

View File

@@ -642,6 +642,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
def guidance_scale(self):
return self._guidance_scale
@property
def skip_guidance_layers(self):
return self._skip_guidance_layers
@property
def clip_skip(self):
return self._clip_skip
@@ -694,6 +698,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
skip_guidance_layers: List[int] = None,
skip_layer_guidance_scale: int = 2.8,
skip_layer_guidance_stop: int = 0.2,
skip_layer_guidance_start: int = 0.01,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -778,6 +786,22 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
skip_guidance_layers (`List[int]`, *optional*):
A list of integers that specify layers to skip during guidance. If not provided, all layers will be
used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
`skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
with a scale of `1`.
skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
`skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
`skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
StabiltyAI for Stable Diffusion 3.5 Medium is 0.2.
skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
`skip_guidance_layers` will start. The guidance will be applied to the layers specified in
`skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
Examples:
@@ -809,6 +833,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
)
self._guidance_scale = guidance_scale
self._skip_layer_guidance_scale = skip_layer_guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
@@ -851,6 +876,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
)
if self.do_classifier_free_guidance:
if skip_guidance_layers is not None:
original_prompt_embeds = prompt_embeds
original_pooled_prompt_embeds = pooled_prompt_embeds
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
@@ -879,7 +907,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = (
torch.cat([latents] * 2)
if self.do_classifier_free_guidance and skip_guidance_layers is None
else latents
)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
@@ -896,6 +928,25 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
should_skip_layers = (
True
if i > num_inference_steps * skip_layer_guidance_start
and i < num_inference_steps * skip_layer_guidance_stop
else False
)
if skip_guidance_layers is not None and should_skip_layers:
noise_pred_skip_layers = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=original_prompt_embeds,
pooled_projections=original_pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
skip_layers=skip_guidance_layers,
)[0]
noise_pred = (
noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype

View File

@@ -147,3 +147,23 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_skip_layers(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Forward pass without skipping layers
output_full = model(**inputs_dict).sample
# Forward pass with skipping layers 0 (since there's only one layer in this test setup)
inputs_dict_with_skip = inputs_dict.copy()
inputs_dict_with_skip["skip_layers"] = [0]
output_skip = model(**inputs_dict_with_skip).sample
# Check that the outputs are different
self.assertFalse(
torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
)
# Check that the outputs have the same shape
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")