<!--Copyright 2024 Marigold authors and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Marigold Pipelines for Computer Vision Tasks

[Marigold](marigold) is a novel diffusion-based dense prediction approach and a set of pipelines for various computer vision tasks, such as monocular depth estimation.

This guide will show you how to use Marigold to obtain fast and high-quality predictions for images and videos.

Each pipeline supports one computer vision task, which takes an input RGB image and produces a *prediction* of the modality of interest, such as a depth map of the input image.
Currently, the following tasks are implemented:

| Pipeline | Predicted Modalities | Demos |
|----------|----------------------|:-----:|
| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-lcm), [Slow Original Demo (DDIM)](https://huggingface.co/spaces/prs-eth/marigold) |
| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-normals-lcm) |

The original checkpoints can be found under the [PRS-ETH](https://huggingface.co/prs-eth/) Hugging Face organization.
These checkpoints are meant to work with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold).
The original code can also be used to train new checkpoints.

| Checkpoint | Modality | Comment |
|------------|----------|---------|
| [prs-eth/marigold-v1-0](https://huggingface.co/prs-eth/marigold-v1-0) | Depth | The first Marigold Depth checkpoint, which predicts *affine-invariant depth* maps. The performance of this checkpoint in benchmarks was studied in the original [paper](https://huggingface.co/papers/2312.02145). Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. Affine-invariant depth prediction has a range of values in each pixel between 0 (near plane) and 1 (far plane); both planes are chosen by the model as part of the inference process. See the `MarigoldImageProcessor` reference for visualization utilities. |
| [prs-eth/marigold-lcm-v1-0](https://huggingface.co/prs-eth/marigold-lcm-v1-0) | Depth | The fast Marigold Depth checkpoint, fine-tuned from `prs-eth/marigold-v1-0`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. |
| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | A preview checkpoint for the Marigold Normals pipeline. Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. The surface normals predictions are unit-length 3D vectors with values in the range from -1 to 1. *This checkpoint will be phased out after the release of the `v1-0` version.* |
| [prs-eth/marigold-normals-lcm-v0-1](https://huggingface.co/prs-eth/marigold-normals-lcm-v0-1) | Normals | The fast Marigold Normals checkpoint, fine-tuned from `prs-eth/marigold-normals-v0-1`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. *This checkpoint will be phased out after the release of the `v1-0` version.* |

The examples below are mostly given for depth prediction, but they can be applied universally to the other supported modalities.
We showcase the predictions using the same input image of Albert Einstein generated by Midjourney.
This makes it easier to compare visualizations of the predictions across various modalities and checkpoints.

<div class="flex gap-4" style="justify-content: center; width: 100%;">
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://marigoldmonodepth.github.io/images/einstein.jpg"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Example input image for all Marigold pipelines
    </figcaption>
  </div>
</div>

### Depth Prediction Quick Start

To get the first depth prediction, load the `prs-eth/marigold-depth-lcm-v1-0` checkpoint into the `MarigoldDepthPipeline`, put the image through the pipeline, and save the predictions:

```python
import diffusers
import torch

pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
    "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
).to("cuda")

image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
depth = pipe(image)

vis = pipe.image_processor.visualize_depth(depth.prediction)
vis[0].save("einstein_depth.png")

depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)
depth_16bit[0].save("einstein_depth_16bit.png")
```


The visualization function for depth [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth`] applies one of [matplotlib's colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html) (`Spectral` by default) to map the predicted pixel values from a single-channel `[0, 1]` depth range into an RGB image.
With the `Spectral` colormap, near pixels are painted red and far pixels are assigned a blue color.
The 16-bit PNG file stores the single-channel values mapped linearly from the `[0, 1]` range into `[0, 65535]`.
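
In practice, prefer the built-in helpers shown above; but as a rough illustration of what these two mappings amount to — assuming `depth.prediction` is a NumPy array with values in `[0, 1]` and a trailing channel axis (both assumptions, and the exact colormap handling of the helper may differ) — the conversion could be done by hand like this:

```python
import matplotlib
import numpy as np
from PIL import Image

prediction = np.asarray(depth.prediction[0])[..., 0]  # first image, drop the channel axis

# 16-bit PNG: map [0, 1] linearly into [0, 65535] and store as a single-channel image
depth_16bit = (prediction * 65535.0).round().astype(np.uint16)
Image.fromarray(depth_16bit).save("einstein_depth_16bit_manual.png")

# Colormapped visualization: "Spectral" paints low (near) values red and high (far) values blue
colored = matplotlib.colormaps["Spectral"](prediction)[..., :3]  # RGBA -> RGB, floats in [0, 1]
Image.fromarray((colored * 255).astype(np.uint8)).save("einstein_depth_colored_manual.png")
```
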

Below are the raw and the visualized predictions; as can be seen, dark areas (mustache) are easier to distinguish in the visualization:

<div class="flex gap-4">
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth_16bit.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Predicted depth (16-bit PNG)
    </figcaption>
  </div>
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Predicted depth visualization (Spectral)
    </figcaption>
  </div>
</div>

### Surface Normals Prediction Quick Start

Load the `prs-eth/marigold-normals-lcm-v0-1` checkpoint into the `MarigoldNormalsPipeline`, put the image through the pipeline, and save the predictions:

```python
import diffusers
import torch

pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
    "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
).to("cuda")

image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
normals = pipe(image)

vis = pipe.image_processor.visualize_normals(normals.prediction)
vis[0].save("einstein_normals.png")
```


The visualization function for normals [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] maps the three-dimensional prediction with pixel values in the range `[-1, 1]` into an RGB image.
The visualization function supports flipping surface normals axes to make the visualization compatible with other choices of the frame of reference.
Conceptually, each pixel is painted according to the surface normal vector in the frame of reference, where the `X` axis points right, the `Y` axis points up, and the `Z` axis points at the viewer.
Below is the visualized prediction:

<div class="flex gap-4" style="justify-content: center; width: 100%;">
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_normals.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Predicted surface normals visualization
    </figcaption>
  </div>
</div>

In this example, the nose tip almost certainly has a point on the surface at which the surface normal vector points straight at the viewer, meaning that its coordinates are `[0, 0, 1]`.
This vector maps to the RGB `[128, 128, 255]`, which corresponds to the violet-blue color.
Similarly, a surface normal on the cheek in the right part of the image has a large `X` component, which increases the red hue.
Points on the shoulders pointing up with a large `Y` promote a green color.
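
A minimal sketch of this mapping, assuming `normals.prediction` is a NumPy array of shape `(num_images, height, width, 3)` with values in `[-1, 1]` (the shape is an assumption; the built-in helper should be preferred in practice):

```python
import numpy as np
from PIL import Image

prediction = np.asarray(normals.prediction[0])  # (height, width, 3), values in [-1, 1]

# Map each component from [-1, 1] into [0, 255]; the nose-tip normal [0, 0, 1]
# lands close to [128, 128, 255], the violet-blue color described above.
rgb = ((prediction + 1.0) * 0.5 * 255.0).round().clip(0, 255).astype(np.uint8)
Image.fromarray(rgb).save("einstein_normals_manual.png")
```

If a different frame of reference is needed, use the axis-flipping options of the built-in helper mentioned above; see the [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] reference for the exact argument names.
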
### Speeding up inference

The above quick start snippets are already optimized for speed: they load the LCM checkpoint, use the `fp16` variant of weights and computation, and perform just one denoising diffusion step.
The `pipe(image)` call completes in 280ms on an RTX 3090 GPU.
Internally, the input image is encoded with the Stable Diffusion VAE encoder, then the U-Net performs one denoising step, and finally, the prediction latent is decoded with the VAE decoder into pixel space.
In this case, two out of three module calls are dedicated to converting between the pixel and latent spaces of the LDM.
Because Marigold's latent space is compatible with that of the base Stable Diffusion, it is possible to speed up the pipeline call by more than 3x (85ms on an RTX 3090) by using a [lightweight replacement of the SD VAE](autoencoder_tiny):

```diff
  import diffusers
  import torch

  pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
      "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
  ).to("cuda")

+ pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
+     "madebyollin/taesd", torch_dtype=torch.float16
+ ).cuda()

  image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
  depth = pipe(image)
```


As suggested in [Optimizations](torch2.0), adding `torch.compile` may squeeze extra performance depending on the target hardware:

```diff
  import diffusers
  import torch

  pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
      "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
  ).to("cuda")

+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

  image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
  depth = pipe(image)
```

## Qualitative Comparison with Depth Anything

With the above speed optimizations, Marigold delivers more detailed predictions faster than [Depth Anything](https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything), even when the latter uses its largest checkpoint, [LiheYoung/depth-anything-large-hf](https://huggingface.co/LiheYoung/depth-anything-large-hf):

<div class="flex gap-4">
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Marigold LCM fp16 with Tiny AutoEncoder
    </figcaption>
  </div>
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/bfe7cb56ca1cc0811b328212472350879dfa7f8b/marigold/einstein_depthanything_large.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Depth Anything Large
    </figcaption>
  </div>
</div>

## Maximizing Precision and Ensembling

Marigold pipelines have a built-in ensembling mechanism combining multiple predictions from different random latents.
This is a brute-force way of improving the precision of predictions, capitalizing on the generative nature of diffusion.
The ensembling path is activated automatically when the `ensemble_size` argument is set to a value greater than `1`.
When aiming for maximum precision, it makes sense to adjust `num_inference_steps` simultaneously with `ensemble_size`.
The recommended values vary across checkpoints but primarily depend on the scheduler type.
The effect of ensembling is particularly visible with surface normals:

```python
import diffusers

model_path = "prs-eth/marigold-normals-v1-0"

model_paper_kwargs = {
    diffusers.schedulers.DDIMScheduler: {
        "num_inference_steps": 10,
        "ensemble_size": 10,
    },
    diffusers.schedulers.LCMScheduler: {
        "num_inference_steps": 4,
        "ensemble_size": 5,
    },
}

image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")

pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(model_path).to("cuda")
pipe_kwargs = model_paper_kwargs[type(pipe.scheduler)]

normals = pipe(image, **pipe_kwargs)

vis = pipe.image_processor.visualize_normals(normals.prediction)
vis[0].save("einstein_normals.png")
```

<div class="flex gap-4">
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_normals.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Surface normals, no ensembling
    </figcaption>
  </div>
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_normals.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Surface normals, with ensembling
    </figcaption>
  </div>
</div>

As can be seen, all areas with fine-grained structures, such as hair, got more conservative and, on average, more correct predictions.
Such a result is more suitable for precision-sensitive downstream tasks, such as 3D reconstruction.

## Quantitative Evaluation

To evaluate Marigold quantitatively on standard leaderboards and benchmarks (such as NYU, KITTI, and other datasets), follow the evaluation protocol outlined in the paper: load the full-precision fp32 model and use appropriate values for `num_inference_steps` and `ensemble_size`.
Optionally seed randomness to ensure reproducibility. Maximizing `batch_size` will deliver maximum device utilization; see the note after the snippet below.

```python
import diffusers
import torch

device = "cuda"
seed = 2024
model_path = "prs-eth/marigold-v1-0"

model_paper_kwargs = {
    diffusers.schedulers.DDIMScheduler: {
        "num_inference_steps": 50,
        "ensemble_size": 10,
    },
    diffusers.schedulers.LCMScheduler: {
        "num_inference_steps": 4,
        "ensemble_size": 10,
    },
}

image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")

generator = torch.Generator(device=device).manual_seed(seed)
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(model_path).to(device)
pipe_kwargs = model_paper_kwargs[type(pipe.scheduler)]

depth = pipe(image, generator=generator, **pipe_kwargs)

# evaluate metrics
```
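
If the device still has headroom, the ensemble members can be processed in larger batches by passing the pipeline's `batch_size` argument; a minimal tweak of the call above (the value `10` is an assumption and should be adjusted to the available memory):

```diff
- depth = pipe(image, generator=generator, **pipe_kwargs)
+ depth = pipe(image, generator=generator, batch_size=10, **pipe_kwargs)
```
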
## Using Predictive Uncertainty

The ensembling mechanism built into Marigold pipelines combines multiple predictions obtained from different random latents.
As a side effect, it can be used to quantify epistemic (model) uncertainty; simply specify `ensemble_size` greater than 1 and set `output_uncertainty=True`.
The resulting uncertainty will be available in the `uncertainty` field of the output.
It can be visualized as follows:

```python
import diffusers
import torch

pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
    "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
).to("cuda")

image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
depth = pipe(
    image,
    ensemble_size=10,  # any number greater than 1; higher values yield higher precision
    output_uncertainty=True,
)

uncertainty = pipe.image_processor.visualize_uncertainty(depth.uncertainty)
uncertainty[0].save("einstein_depth_uncertainty.png")
```

<div class="flex gap-4">
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_depth_uncertainty.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Depth uncertainty
    </figcaption>
  </div>
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_normals_uncertainty.png"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">
      Surface normals uncertainty
    </figcaption>
  </div>
</div>

The interpretation of uncertainty is easy: higher values (white) correspond to pixels where the model struggles to make consistent predictions.
Evidently, the depth model is the least confident around discontinuity edges, where the object depth changes drastically.
The surface normals model is the least confident in fine-grained structures, such as hair, and dark areas, such as the collar.

## Frame-by-frame Video Processing with Temporal Consistency

Due to Marigold's generative nature, each prediction is unique and defined by the random noise sampled for the latent initialization.
This becomes an obvious drawback compared to traditional end-to-end dense regression networks, as exemplified in the following videos:

<div class="flex gap-4">
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/25024b5443a6c1357492751fd09355bd3f967845/marigold/marigold_obama.gif"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">Input video</figcaption>
  </div>
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/25024b5443a6c1357492751fd09355bd3f967845/marigold/marigold_obama_depth_independent.gif"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">Marigold Depth applied to input video frames independently</figcaption>
  </div>
</div>


To address this issue, it is possible to pass the `latents` argument to the pipelines, which defines the starting point of diffusion.
Empirically, we found that a convex combination of the very same starting-point noise latent and the latent corresponding to the previous frame's prediction gives sufficiently smooth results, as implemented in the snippet below:

```python
import imageio
from PIL import Image
from tqdm import tqdm
import diffusers
import torch

device = "cuda"
path_in = "obama.mp4"
path_out = "obama_depth.gif"

pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
    "prs-eth/marigold-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
).to(device)
pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
    "madebyollin/taesd", torch_dtype=torch.float16
).to(device)
pipe.set_progress_bar_config(disable=True)

with imageio.get_reader(path_in) as reader:
    size = reader.get_meta_data()['size']
    last_frame_latent = None
    # Starting noise shared by all frames, sized for the 768 processing resolution at the 1/8 latent scale
    latent_common = torch.randn(
        (1, 4, 768 * size[1] // (8 * max(size)), 768 * size[0] // (8 * max(size)))
    ).to(device=device, dtype=torch.float16)

    out = []
    for frame_id, frame in tqdm(enumerate(reader), desc="Processing Video"):
        frame = Image.fromarray(frame)
        latents = latent_common
        if last_frame_latent is not None:
            # Blend the common starting noise with the latent of the previous frame's prediction
            latents = 0.9 * latents + 0.1 * last_frame_latent

        depth = pipe(
            frame, match_input_resolution=False, latents=latents, output_latent=True,
        )
        last_frame_latent = depth.latent
        out.append(pipe.image_processor.visualize_depth(depth.prediction)[0])

    diffusers.utils.export_to_gif(out, path_out, fps=reader.get_meta_data()['fps'])
```

Here, the diffusion process starts from the given computed latent.
The pipeline call sets `output_latent=True` to expose `depth.latent`, which then contributes to the next frame's latent initialization.
The result is much more stable now:

<div class="flex gap-4">
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/25024b5443a6c1357492751fd09355bd3f967845/marigold/marigold_obama_depth_independent.gif"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">Marigold Depth applied to input video frames independently</figcaption>
  </div>
  <div style="flex: 1 1 50%; max-width: 50%;">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/25024b5443a6c1357492751fd09355bd3f967845/marigold/marigold_obama_depth_consistent.gif"/>
    <figcaption class="mt-1 text-center text-sm text-gray-500">Marigold Depth with forced latents initialization</figcaption>
  </div>
</div>

Hopefully, you will find Marigold useful for solving your downstream tasks, be it as part of a broader generative workflow or as part of a perception task, such as 3D reconstruction.