
Extend Support for callback_on_step_end for AuraFlow and LuminaText2Img Pipelines (#10746)

* Add support for callback_on_step_end for
AuraFlowPipeline and LuminaText2ImgPipeline.

* Apply the suggestions from code review for lumina and auraflow

Co-authored-by: hlky <hlky@hlky.ac>

* Update missing inputs and imports.

* Add input field.

* Apply suggestions from code review-2

Co-authored-by: hlky <hlky@hlky.ac>

* Apply the suggestions from review for unused imports.

Co-authored-by: hlky <hlky@hlky.ac>

* make style.

* Update pipeline_aura_flow.py

* Update pipeline_lumina.py

* Update pipeline_lumina.py

* Update pipeline_aura_flow.py

* Update pipeline_lumina.py

---------

Co-authored-by: hlky <hlky@hlky.ac>
Author: Parag Ekbote
Date: 2025-02-16 09:28:57 -08:00
Committed by: GitHub
Parent: 952b9131a2
Commit: 3e99b5677e

2 changed files with 80 additions and 2 deletions
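For context, a minimal usage sketch of the hook this commit adds. The checkpoint id `fal/AuraFlow` and the `log_step` helper are illustrative, not part of the diff; the callback signature and the requirement to return the kwargs dict follow from the loop code below:

import torch
from diffusers import AuraFlowPipeline


def log_step(pipe, step, timestep, callback_kwargs):
    # `callback_kwargs` holds the tensors requested via
    # `callback_on_step_end_tensor_inputs` ("latents" and/or "prompt_embeds").
    print(f"step {step}: latents shape = {callback_kwargs['latents'].shape}")
    # The callback must return the kwargs dict; the pipeline pops
    # "latents" / "prompt_embeds" back out of it after each step.
    return callback_kwargs


pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
image = pipe(
    "a photo of a lavender field at dusk",
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]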

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import T5Tokenizer, UMT5EncoderModel

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
@@ -131,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline):
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+    ]

     def __init__(
         self,
@@ -159,12 +164,19 @@ class AuraFlowPipeline(DiffusionPipeline):
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             raise ValueError(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
             )

+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -387,6 +399,14 @@ class AuraFlowPipeline(DiffusionPipeline):
         self.vae.decoder.conv_in.to(dtype)
         self.vae.decoder.mid_block.to(dtype)

+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -408,6 +428,10 @@ class AuraFlowPipeline(DiffusionPipeline):
         max_sequence_length: int = 256,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -462,6 +486,15 @@ class AuraFlowPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is
+                called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
+                timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
+                specified by `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the
+                list will be passed as the `callback_kwargs` argument. You will only be able to include variables
+                listed in the `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.

         Examples:
@@ -483,8 +516,11 @@ class AuraFlowPipeline(DiffusionPipeline):
             negative_prompt_embeds,
             prompt_attention_mask,
             negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+
         # 2. Determine batch size.
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -541,6 +577,7 @@ class AuraFlowPipeline(DiffusionPipeline):
         # 6. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
@@ -567,6 +604,15 @@ class AuraFlowPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

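A short sketch of the validation path the `check_inputs` change above introduces: requesting a tensor name outside `_callback_tensor_inputs` raises a `ValueError` before the denoising loop runs. The checkpoint id is again illustrative:

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)

try:
    pipe(
        "test prompt",
        num_inference_steps=1,
        callback_on_step_end=lambda pipe, step, timestep, kwargs: kwargs,
        callback_on_step_end_tensor_inputs=["noise_pred"],  # not a registered tensor input
    )
except ValueError as err:
    # Expected message, per the f-string in check_inputs:
    # "`callback_on_step_end_tensor_inputs` has to be in
    #  ['latents', 'prompt_embeds'], but found ['noise_pred']"
    print(err)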
src/diffusers/pipelines/lumina/pipeline_lumina.py

@@ -17,11 +17,12 @@ import inspect
 import math
 import re
 import urllib.parse as ul
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import AutoModel, AutoTokenizer

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL
 from ...models.embeddings import get_2d_rotary_pos_embed_lumina
@@ -174,6 +175,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+    ]

     def __init__(
         self,
@@ -395,12 +400,20 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             raise ValueError(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
             )

+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -644,6 +657,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
         max_sequence_length: int = 256,
         scaling_watershed: Optional[float] = 1.0,
         proportional_attn: Optional[bool] = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -735,7 +752,11 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+
         cross_attention_kwargs = {}

         # 2. Define call parameters
@@ -797,6 +818,8 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
             latents,
         )
+        self._num_timesteps = len(timesteps)
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -886,6 +909,15 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
                     progress_bar.update()

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
                 if XLA_AVAILABLE:
                     xm.mark_step()
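A matching sketch for the Lumina side, collecting intermediate latents through the new hook; `collect_latents` and the `Alpha-VLLM/Lumina-Next-SFT-diffusers` checkpoint id are illustrative:

import torch
from diffusers import LuminaText2ImgPipeline

intermediates = []


def collect_latents(pipe, step, timestep, callback_kwargs):
    # Stash a CPU copy of the latents after every denoising step.
    intermediates.append(callback_kwargs["latents"].detach().cpu())
    return callback_kwargs


pipe = LuminaText2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    "a serene mountain lake at sunrise",
    callback_on_step_end=collect_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
# `intermediates` now holds one latent tensor per step; with this commit,
# len(intermediates) == pipe.num_timesteps via the new property.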