mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add callbacks to WuerstchenDecoderPipeline and WuerstchenCombinedPipeline (#5154)
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -202,6 +202,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -240,6 +242,12 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -315,7 +323,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler)
|
||||
|
||||
# 6. Run denoising loop
|
||||
for t in self.progress_bar(timesteps[:-1]):
|
||||
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
|
||||
ratio = t.expand(latents.size(0)).to(dtype)
|
||||
effnet = (
|
||||
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
|
||||
@@ -343,6 +351,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Scale and decode the image latents with vq-vae
|
||||
latents = self.vqgan.config.scale_factor * latents
|
||||
images = self.vqgan.decode(latents).sample.clamp(0, 1)
|
||||
|
||||
@@ -161,6 +161,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
prior_callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
prior_callback_steps: int = 1,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -222,6 +226,18 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
prior_callback (`Callable`, *optional*):
|
||||
A function that will be called every `prior_callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `prior_callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
prior_callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -244,6 +260,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
latents=latents,
|
||||
output_type="pt",
|
||||
return_dict=False,
|
||||
callback=prior_callback,
|
||||
callback_steps=prior_callback_steps,
|
||||
)
|
||||
image_embeddings = prior_outputs[0]
|
||||
|
||||
@@ -257,6 +275,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
generator=generator,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
Reference in New Issue
Block a user