add skip_layers argument to SD3 transformer model class (#9880)
* add skip_layers argument to SD3 transformer model class
* add unit test for skip_layers in stable diffusion 3
* sd3: pipeline should support skip layer guidance
* up

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
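For orientation, here is a minimal usage sketch of the new pipeline-level arguments added by this PR (the checkpoint name and sampler settings are illustrative; the layer indices follow the Stable Diffusion 3.5 Medium recommendation documented in the docstring below):

import torch
from diffusers import StableDiffusion3Pipeline

# Load SD 3.5 Medium (illustrative checkpoint; any SD3-family model works).
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

# skip_guidance_layers enables skip layer guidance (SLG): an extra forward
# pass with the listed transformer blocks skipped steers sampling away from
# the degraded prediction.
image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
    guidance_scale=7.0,
    skip_guidance_layers=[7, 8, 9],  # recommended for SD 3.5 Medium
).images[0]
image.save("slg_example.png")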
@@ -268,6 +268,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
-            timestep ( `torch.LongTensor`):
+            timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                 A list of tensors that if specified are added to the residuals of transformer blocks.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.

         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,7 +320,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
+            # Skip specified layers
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -336,8 +342,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                     temb,
                     **ckpt_kwargs,
                 )
-
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                 )
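The model-side change reduces to a simple pattern: compute is_skip per block and fall through when it is set. Below is a self-contained sketch of that pattern on a toy module (the TinyStack class is hypothetical, not the actual transformer internals):

import torch
import torch.nn as nn

class TinyStack(nn.Module):
    """Toy stand-in for a stack of transformer blocks."""

    def __init__(self, num_layers: int = 4, dim: int = 8):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x, skip_layers=None):
        for index_block, block in enumerate(self.blocks):
            # Same check as in the diff: skip the block if its index is listed.
            is_skip = skip_layers is not None and index_block in skip_layers
            if not is_skip:
                x = block(x)
        return x

model = TinyStack()
x = torch.randn(2, 8)
full, skipped = model(x), model(x, skip_layers=[1])
print(full.shape == skipped.shape, torch.allclose(full, skipped))  # True False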
@@ -642,6 +642,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
     def guidance_scale(self):
         return self._guidance_scale

+    @property
+    def skip_guidance_layers(self):
+        return self._skip_guidance_layers
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -694,6 +698,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
+        skip_guidance_layers: List[int] = None,
+        skip_layer_guidance_scale: int = 2.8,
+        skip_layer_guidance_stop: int = 0.2,
+        skip_layer_guidance_start: int = 0.01,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -778,6 +786,22 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+            skip_guidance_layers (`List[int]`, *optional*):
+                A list of integers that specify layers to skip during guidance. If not provided, all layers will be
+                used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
+                Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
+            skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
+                `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
+                with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
+                with a scale of `1`.
+            skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
+                `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
+                `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
+                StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
+            skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
+                `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
+                `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
+                StabilityAI for Stable Diffusion 3.5 Medium is 0.01.

         Examples:

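To make the start/stop fractions concrete, here is the arithmetic for the step window in which skip layer guidance is active (a worked example, assuming the recommended defaults and a 28-step schedule):

# Steps where SLG is applied, per the condition used later in the diff:
# i > num_inference_steps * start  and  i < num_inference_steps * stop
num_inference_steps = 28
skip_layer_guidance_start = 0.01
skip_layer_guidance_stop = 0.2

active = [
    i
    for i in range(num_inference_steps)
    if num_inference_steps * skip_layer_guidance_start < i < num_inference_steps * skip_layer_guidance_stop
]
print(active)  # [1, 2, 3, 4, 5] -> SLG only runs during the early denoising steps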
@@ -809,6 +833,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         )

         self._guidance_scale = guidance_scale
+        self._skip_layer_guidance_scale = skip_layer_guidance_scale
         self._clip_skip = clip_skip
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
@@ -851,6 +876,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         )

         if self.do_classifier_free_guidance:
+            if skip_guidance_layers is not None:
+                original_prompt_embeds = prompt_embeds
+                original_pooled_prompt_embeds = pooled_prompt_embeds
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

@@ -879,7 +907,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                     continue

                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = (
+                    torch.cat([latents] * 2)
+                    if self.do_classifier_free_guidance and skip_guidance_layers is None
+                    else latents
+                )
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

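Note that when skip_guidance_layers is set, the diff leaves the latents un-duplicated. For readers unfamiliar with the batching idiom being modified here, this generic sketch (stand-in tensors, not pipeline code) shows the usual concat/chunk pattern that pairs negative and positive conditioning into one forward pass:

import torch

latents = torch.randn(1, 16, 64, 64)
neg_embeds, pos_embeds = torch.randn(1, 77, 4096), torch.randn(1, 77, 4096)

# One joint forward pass over a doubled batch...
prompt_embeds = torch.cat([neg_embeds, pos_embeds], dim=0)
latent_model_input = torch.cat([latents] * 2)

# ...then split the output back into unconditional / conditional predictions.
noise_pred = latent_model_input * 0.5  # stand-in for the transformer call
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
print(noise_pred_uncond.shape)  # torch.Size([1, 16, 64, 64])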
@@ -896,6 +928,25 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    should_skip_layers = (
+                        True
+                        if i > num_inference_steps * skip_layer_guidance_start
+                        and i < num_inference_steps * skip_layer_guidance_stop
+                        else False
+                    )
+                    if skip_guidance_layers is not None and should_skip_layers:
+                        noise_pred_skip_layers = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=original_prompt_embeds,
+                            pooled_projections=original_pooled_prompt_embeds,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                            skip_layers=skip_guidance_layers,
+                        )[0]
+                        noise_pred = (
+                            noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
+                        )

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
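The resulting update combines two difference terms: standard classifier-free guidance plus the skip-layer term. A scalar restatement of the arithmetic (floats standing in for the three model predictions):

guidance_scale = 7.0
skip_layer_guidance_scale = 2.8

# Floats standing in for the three predictions in the diff above.
noise_pred_uncond, noise_pred_text, noise_pred_skip_layers = 0.1, 0.5, 0.4

# Standard classifier-free guidance:
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Skip layer guidance pushes away from the prediction made with layers skipped:
noise_pred = noise_pred + (noise_pred_text - noise_pred_skip_layers) * skip_layer_guidance_scale
print(noise_pred)  # 0.1 + 7.0 * 0.4 + 2.8 * 0.1 ≈ 3.18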
@@ -147,3 +147,23 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"SD3Transformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    def test_skip_layers(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        # Forward pass without skipping layers
+        output_full = model(**inputs_dict).sample
+
+        # Forward pass skipping layer 0 (the only layer in this test setup)
+        inputs_dict_with_skip = inputs_dict.copy()
+        inputs_dict_with_skip["skip_layers"] = [0]
+        output_skip = model(**inputs_dict_with_skip).sample
+
+        # Check that the outputs are different
+        self.assertFalse(
+            torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
+        )
+
+        # Check that the outputs have the same shape
+        self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
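To exercise the new test in isolation, an invocation along these lines should work (the file path is assumed from where the SD3.5 transformer tests live in the diffusers test suite): pytest tests/models/transformers/test_models_transformer_sd3.py -k test_skip_layers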