1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add callbacks to WuerstchenDecoderPipeline and WuerstchenCombinedPipeline (#5154)

This commit is contained in:
Carson Katri
2023-09-25 13:26:53 -04:00
committed by GitHub
parent 28254c79b6
commit 6281d2066b
2 changed files with 33 additions and 2 deletions

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import numpy as np
import torch
@@ -202,6 +202,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
@@ -240,6 +242,12 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
@@ -315,7 +323,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop
for t in self.progress_bar(timesteps[:-1]):
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
@@ -343,6 +351,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
generator=generator,
).prev_sample
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Scale and decode the image latents with vq-vae
latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1)

View File

@@ -161,6 +161,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
prior_callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
prior_callback_steps: int = 1,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
@@ -222,6 +226,18 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback (`Callable`, *optional*):
A function that will be called every `prior_callback_steps` steps during inference of the prior pipeline.
The function will be called with the following arguments: `prior_callback(step: int, timestep: int, latents: torch.FloatTensor)`.
prior_callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `prior_callback` function will be called. If not specified, the callback will be
called at every step.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
@@ -244,6 +260,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
latents=latents,
output_type="pt",
return_dict=False,
callback=prior_callback,
callback_steps=prior_callback_steps,
)
image_embeddings = prior_outputs[0]
@@ -257,6 +275,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
generator=generator,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)
return outputs