mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[docs] Add controlnet example to marigold (#8289)
* initial doc
* fix wrong LCM sentence
* implement binary colormap without requiring matplotlib
  update section about Marigold for ControlNet
  update formatting of marigold_usage.md
* fix indentation

---------

Co-authored-by: anton <anton.obukhov@gmail.com>
@@ -14,9 +14,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
[Marigold](marigold) is a novel diffusion-based dense prediction approach, and a set of pipelines for various computer vision tasks, such as monocular depth estimation.
|
||||
|
||||
This guide will show you how to use Marigold to obtain fast and high-quality predictions for images and videos.
|
||||
|
||||
Each pipeline supports one computer vision task: it takes an RGB image as input and produces a *prediction* of the modality of interest, such as a depth map of the input image.
|
||||
Currently, the following tasks are implemented:
|
||||
|
||||
| Pipeline | Predicted Modalities | Demos |
|
||||
@@ -24,8 +24,8 @@ Currently, the following tasks are implemented:
|
||||
| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-lcm), [Slow Original Demo (DDIM)](https://huggingface.co/spaces/prs-eth/marigold) |
|
||||
| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-normals-lcm) |
|
||||
|
||||
The original checkpoints can be found under the [PRS-ETH](https://huggingface.co/prs-eth/) Hugging Face organization.
|
||||
These checkpoints are meant to work with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold).
|
||||
The original code can also be used to train new checkpoints.
|
||||
|
||||
| Checkpoint | Modality | Comment |
|
||||
@@ -34,22 +34,22 @@ The original code can also be used to train new checkpoints.
|
||||
| [prs-eth/marigold-lcm-v1-0](https://huggingface.co/prs-eth/marigold-lcm-v1-0) | Depth | The fast Marigold Depth checkpoint, fine-tuned from `prs-eth/marigold-v1-0`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. |
|
||||
| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | A preview checkpoint for the Marigold Normals pipeline. Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. The surface normals predictions are unit-length 3D vectors with values in the range from -1 to 1. *This checkpoint will be phased out after the release of `v1-0` version.* |
|
||||
| [prs-eth/marigold-normals-lcm-v0-1](https://huggingface.co/prs-eth/marigold-normals-lcm-v0-1) | Normals | The fast Marigold Normals checkpoint, fine-tuned from `prs-eth/marigold-normals-v0-1`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. *This checkpoint will be phased out after the release of `v1-0` version.* |
|
||||
The examples below are mostly given for depth prediction, but they can be universally applied with other supported modalities.
|
||||
We showcase the predictions using the same input image of Albert Einstein generated by Midjourney.
|
||||
This makes it easier to compare visualizations of the predictions across various modalities and checkpoints.
|
||||
|
||||
<div class="flex gap-4" style="justify-content: center; width: 100%;">
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://marigoldmonodepth.github.io/images/einstein.jpg"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Example input image for all Marigold pipelines
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Depth Prediction Quick Start
|
||||
|
||||
To get the first depth prediction, load the `prs-eth/marigold-depth-lcm-v1-0` checkpoint into the `MarigoldDepthPipeline` pipeline, run the image through the pipeline, and save the predictions:
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
@@ -69,7 +69,7 @@ depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)
|
||||
depth_16bit[0].save("einstein_depth_16bit.png")
|
||||
```
|
||||
|
||||
The visualization function for depth [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth`] applies one of [matplotlib's colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html) (`Spectral` by default) to map the predicted pixel values from a single-channel `[0, 1]` depth range into an RGB image.
|
||||
With the `Spectral` colormap, near pixels are painted red and far pixels are painted blue.
|
||||
The 16-bit PNG file stores the single channel values mapped linearly from the `[0, 1]` range into `[0, 65535]`.
|
||||
Below are the raw and the visualized predictions; as can be seen, dark areas (mustache) are easier to distinguish in the visualization:
|
||||
@@ -78,20 +78,20 @@ Below are the raw and the visualized predictions; as can be seen, dark areas (mu
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth_16bit.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Predicted depth (16-bit PNG)
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Predicted depth visualization (Spectral)
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
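The linear 16-bit mapping described above can be sketched in a few lines of NumPy. This is only an illustration of the `[0, 1]` to `[0, 65535]` conversion; the helper names `depth_to_uint16` and `uint16_to_depth` are hypothetical, and the rounding used by `export_depth_to_16bit_png` may differ.

```python
import numpy as np


def depth_to_uint16(depth: np.ndarray) -> np.ndarray:
    """Map depth values from [0, 1] linearly to integers in [0, 65535]."""
    depth = np.clip(depth, 0.0, 1.0)
    # Rounding (rather than truncation) is an assumption of this sketch.
    return np.round(depth * 65535.0).astype(np.uint16)


def uint16_to_depth(values: np.ndarray) -> np.ndarray:
    """Invert the mapping: recover approximate [0, 1] depth from 16-bit values."""
    return values.astype(np.float32) / 65535.0


depth = np.array([[0.0, 0.25], [0.75, 1.0]], dtype=np.float32)
encoded = depth_to_uint16(depth)
decoded = uint16_to_depth(encoded)
```

The round trip loses at most half a quantization step, that is, about `1 / 131070` of the depth range per pixel.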
|
||||
|
||||
### Surface Normals Prediction Quick Start
|
||||
|
||||
Load the `prs-eth/marigold-normals-lcm-v0-1` checkpoint into the `MarigoldNormalsPipeline` pipeline, run the image through the pipeline, and save the predictions:
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
@@ -108,8 +108,8 @@ vis = pipe.image_processor.visualize_normals(normals.prediction)
|
||||
vis[0].save("einstein_normals.png")
|
||||
```
|
||||
|
||||
The visualization function for normals [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] maps the three-dimensional prediction with pixel values in the range `[-1, 1]` into an RGB image.
|
||||
The visualization function supports flipping surface normals axes to make the visualization compatible with other choices of the frame of reference.
|
||||
Conceptually, each pixel is painted according to the surface normal vector in the frame of reference, where `X` axis points right, `Y` axis points up, and `Z` axis points at the viewer.
|
||||
Below is the visualized prediction:
|
||||
|
||||
@@ -117,19 +117,19 @@ Below is the visualized prediction:
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_normals.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Predicted surface normals visualization
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
In this example, the nose tip almost certainly contains a point at which the surface normal vector points straight at the viewer, meaning that its coordinates are `[0, 0, 1]`.
|
||||
This vector maps to the RGB `[128, 128, 255]`, which corresponds to the violet-blue color.
|
||||
Similarly, a surface normal on the cheek in the right part of the image has a large `X` component, which increases the red hue.
|
||||
Points on the shoulders pointing up with a large `Y` promote green color.
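The color mapping described above can be reproduced with a short NumPy sketch. It assumes the simplest linear mapping from `[-1, 1]` to `[0, 255]` with rounding, which agrees with the `[0, 0, 1]` to `[128, 128, 255]` example; the helper name `normals_to_rgb` is hypothetical, and the actual `visualize_normals` implementation may differ in details such as axis flipping.

```python
import numpy as np


def normals_to_rgb(normals: np.ndarray) -> np.ndarray:
    """Map normal components from [-1, 1] linearly to 8-bit colors in [0, 255]."""
    normals = np.clip(normals, -1.0, 1.0)
    # Assumption of this sketch: shift to [0, 1], scale to [0, 255], then round.
    return np.round((normals + 1.0) * 0.5 * 255.0).astype(np.uint8)


toward_viewer = normals_to_rgb(np.array([0.0, 0.0, 1.0]))  # Z axis: violet-blue
pointing_right = normals_to_rgb(np.array([1.0, 0.0, 0.0]))  # X axis: more red
pointing_up = normals_to_rgb(np.array([0.0, 1.0, 0.0]))  # Y axis: more green
```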
|
||||
|
||||
### Speeding up inference
|
||||
|
||||
The above quick start snippets are already optimized for speed: they load the LCM checkpoint, use the `fp16` variant of weights and computation, and perform just one denoising diffusion step.
|
||||
The `pipe(image)` call completes in 280 ms on an RTX 3090 GPU.
|
||||
Internally, the input image is encoded with the Stable Diffusion VAE encoder, then the U-Net performs one denoising step, and finally, the prediction latent is decoded with the VAE decoder into pixel space.
|
||||
In this case, two out of three module calls are dedicated to converting between pixel and latent space of LDM.
|
||||
@@ -144,7 +144,7 @@ Because Marigold's latent space is compatible with the base Stable Diffusion, it
|
||||
).to("cuda")
|
||||
|
||||
+ pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
|
||||
+ "madebyollin/taesd", torch_dtype=torch.float16
|
||||
+ ).cuda()
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
@@ -175,23 +175,23 @@ With the above speed optimizations, Marigold delivers predictions with more deta
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Marigold LCM fp16 with Tiny AutoEncoder
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/bfe7cb56ca1cc0811b328212472350879dfa7f8b/marigold/einstein_depthanything_large.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Depth Anything Large
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Maximizing Precision and Ensembling
|
||||
|
||||
Marigold pipelines have a built-in ensembling mechanism combining multiple predictions from different random latents.
|
||||
This is a brute-force way of improving the precision of predictions, capitalizing on the generative nature of diffusion.
|
||||
The ensembling path is activated automatically when the `ensemble_size` argument is set greater than `1`.
|
||||
When aiming for maximum precision, it makes sense to adjust `num_inference_steps` simultaneously with `ensemble_size`.
|
||||
The recommended values vary across checkpoints but primarily depend on the scheduler type.
|
||||
The effect of ensembling is particularly visible with surface normals:
|
||||
|
||||
@@ -226,14 +226,14 @@ vis[0].save("einstein_normals.png")
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_normals.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Surface normals, no ensembling
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_normals.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Surface normals, with ensembling
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -242,7 +242,7 @@ Such a result is more suitable for precision-sensitive downstream tasks, such as
|
||||
|
||||
## Quantitative Evaluation
|
||||
|
||||
To evaluate Marigold quantitatively in standard leaderboards and benchmarks (such as NYU, KITTI, and other datasets), follow the evaluation protocol outlined in the paper: load the full precision fp32 model and use appropriate values for `num_inference_steps` and `ensemble_size`.
|
||||
Optionally seed randomness to ensure reproducibility. Maximizing `batch_size` will deliver maximum device utilization.
|
||||
|
||||
```python
|
||||
@@ -277,7 +277,7 @@ depth = pipe(image, generator=generator, **pipe_kwargs)
|
||||
|
||||
## Using Predictive Uncertainty
|
||||
|
||||
The ensembling mechanism built into Marigold pipelines combines multiple predictions obtained from different random latents.
|
||||
As a side effect, it can be used to quantify epistemic (model) uncertainty; simply specify `ensemble_size` greater than 1 and set `output_uncertainty=True`.
|
||||
The resulting uncertainty will be available in the `uncertainty` field of the output.
|
||||
It can be visualized as follows:
|
||||
@@ -305,14 +305,14 @@ uncertainty[0].save("einstein_depth_uncertainty.png")
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_depth_uncertainty.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Depth uncertainty
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_normals_uncertainty.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Surface normals uncertainty
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
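The idea of deriving uncertainty from an ensemble can be illustrated with a toy NumPy sketch. It assumes the simplest possible reduction, a per-pixel mean as the prediction and a per-pixel standard deviation as the uncertainty; the function name `reduce_ensemble` is hypothetical, and the reduction used inside the Marigold pipelines may differ (for example, by aligning the ensemble members first).

```python
import numpy as np


def reduce_ensemble(predictions: np.ndarray):
    """Reduce an [E, H, W] stack of predictions to a prediction and an uncertainty map.

    Toy reduction: per-pixel mean and per-pixel standard deviation over the
    ensemble axis. Pixels where the members disagree get a higher uncertainty.
    """
    prediction = predictions.mean(axis=0)
    uncertainty = predictions.std(axis=0)
    return prediction, uncertainty


# Three ensemble members for a 1x2 image: they disagree on the first pixel only.
stack = np.array(
    [
        [[0.2, 0.5]],
        [[0.4, 0.5]],
        [[0.6, 0.5]],
    ]
)
prediction, uncertainty = reduce_ensemble(stack)
```

In this toy case the second pixel has zero uncertainty because all members agree, while the first pixel has a nonzero value.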
|
||||
|
||||
@@ -322,7 +322,7 @@ The surface normals model is the least confident in fine-grained structures, suc
|
||||
|
||||
## Frame-by-frame Video Processing with Temporal Consistency
|
||||
|
||||
Due to Marigold's generative nature, each prediction is unique and defined by the random noise sampled for the latent initialization.
|
||||
This becomes an obvious drawback compared to traditional end-to-end dense regression networks, as exemplified in the following videos:
|
||||
|
||||
<div class="flex gap-4">
|
||||
@@ -373,7 +373,7 @@ with imageio.get_reader(path_in) as reader:
|
||||
latents = 0.9 * latents + 0.1 * last_frame_latent
|
||||
|
||||
depth = pipe(
|
||||
frame, match_input_resolution=False, latents=latents, output_latent=True
|
||||
)
|
||||
last_frame_latent = depth.latent
|
||||
out.append(pipe.image_processor.visualize_depth(depth.prediction)[0])
|
||||
@@ -381,7 +381,7 @@ with imageio.get_reader(path_in) as reader:
|
||||
diffusers.utils.export_to_gif(out, path_out, fps=reader.get_meta_data()['fps'])
|
||||
```
|
||||
|
||||
Here, the diffusion process starts from the given computed latent.
|
||||
The pipeline call sets `output_latent=True` so that `depth.latent` is available, and this latent contributes to the next frame's latent initialization.
|
||||
The result is much more stable now:
|
||||
|
||||
@@ -396,4 +396,71 @@ The result is much more stable now:
|
||||
</div>
|
||||
</div>
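The latent blending used above, `0.9 * latents + 0.1 * last_frame_latent`, can be isolated into a small helper. The sketch below uses NumPy arrays as stand-ins for the latent tensors; the name `blend_latents` and the `carry` parameter are hypothetical and only illustrate the weighting.

```python
import numpy as np


def blend_latents(noise, last_frame_latent=None, carry=0.1):
    """Initialize a frame's latent from fresh noise and the previous frame's latent.

    With carry=0.1 this matches the 0.9 / 0.1 weighting in the snippet above.
    The first frame has no predecessor, so its noise is returned unchanged.
    """
    if last_frame_latent is None:
        return noise
    return (1.0 - carry) * noise + carry * last_frame_latent


# Stand-in arrays with a latent-like shape [batch, channels, height, width].
noise = np.zeros((1, 4, 2, 2))
previous = np.ones((1, 4, 2, 2))

first = blend_latents(noise)  # no previous frame yet
blended = blend_latents(noise, previous)  # 90% noise, 10% previous latent
```

A larger `carry` ties consecutive frames more tightly together, while a smaller value leaves each frame closer to an independent sample.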
|
||||
|
||||
## Marigold for ControlNet
|
||||
|
||||
Depth prediction with diffusion models is very commonly used together with ControlNet.
|
||||
Depth crispness plays a crucial role in obtaining high-quality results from ControlNet.
|
||||
As seen in comparisons with other methods above, Marigold excels at that task.
|
||||
The snippet below demonstrates how to load an image, compute depth, and pass it into ControlNet in a compatible format:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import diffusers
|
||||
|
||||
device = "cuda"
|
||||
generator = torch.Generator(device=device).manual_seed(2024)
|
||||
image = diffusers.utils.load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_source.png"
|
||||
)
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-lcm-v1-0", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
|
||||
depth_image = pipe(image, generator=generator).prediction
|
||||
depth_image = pipe.image_processor.visualize_depth(depth_image, color_map="binary")
|
||||
depth_image[0].save("motorcycle_controlnet_depth.png")
|
||||
|
||||
controlnet = diffusers.ControlNetModel.from_pretrained(
|
||||
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipe = diffusers.StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
"SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, variant="fp16", controlnet=controlnet
|
||||
).to("cuda")
|
||||
pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
|
||||
|
||||
controlnet_out = pipe(
|
||||
prompt="high quality photo of a sports bike, city",
|
||||
negative_prompt="",
|
||||
guidance_scale=6.5,
|
||||
num_inference_steps=25,
|
||||
image=depth_image,
|
||||
controlnet_conditioning_scale=0.7,
|
||||
control_guidance_end=0.7,
|
||||
generator=generator,
|
||||
).images
|
||||
controlnet_out[0].save("motorcycle_controlnet_out.png")
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div style="flex: 1 1 33%; max-width: 33%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_source.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Input image
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 33%; max-width: 33%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/8e61e31f9feb7756c0404ceff26f3f0e5d3fe610/marigold/motorcycle_controlnet_depth.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Depth in the format compatible with ControlNet
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 33%; max-width: 33%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/8e61e31f9feb7756c0404ceff26f3f0e5d3fe610/marigold/motorcycle_controlnet_out.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
ControlNet generation, conditioned on depth and prompt: "high quality photo of a sports bike, city"
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Hopefully, you will find Marigold useful for solving your downstream tasks, whether as part of a broader generative workflow or a perception task such as 3D reconstruction.
|
||||
|
||||
@@ -245,9 +245,9 @@ class MarigoldImageProcessor(ConfigMixin):
|
||||
) -> Union[np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the
|
||||
- behavior of matplotlib.colormaps, but allows the user to use the most discriminative color map "Spectral"
- without having to install or import matplotlib. For all other cases, the function will attempt to use the
- native implementation.
+ behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral",
+ "binary") without having to install or import matplotlib. For all other cases, the function will attempt to use
+ the native implementation.
|
||||
|
||||
Args:
|
||||
image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor.
|
||||
@@ -255,7 +255,7 @@ class MarigoldImageProcessor(ConfigMixin):
|
||||
bytes: Whether to return the output as uint8 or floating point image.
|
||||
_force_method:
|
||||
Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom
|
||||
- implementation of the "Spectral" color map (`"custom"`), or rely on autodetection (`None`, default).
+ implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default).
|
||||
|
||||
Returns:
|
||||
An RGB-colorized tensor corresponding to the input image.
|
||||
@@ -265,6 +265,26 @@ class MarigoldImageProcessor(ConfigMixin):
|
||||
if _force_method not in (None, "matplotlib", "custom"):
|
||||
raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.")
|
||||
|
||||
supported_cmaps = {
|
||||
"binary": [
|
||||
(1.0, 1.0, 1.0),
|
||||
(0.0, 0.0, 0.0),
|
||||
],
|
||||
"Spectral": [ # Taken from matplotlib/_cm.py
|
||||
(0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0]
|
||||
(0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
|
||||
(0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
|
||||
(0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
|
||||
(0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
|
||||
(1.0, 1.0, 0.74901960784313726),
|
||||
(0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
|
||||
(0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
|
||||
(0.4, 0.76078431372549016, 0.6470588235294118),
|
||||
(0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
|
||||
(0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1]
|
||||
],
|
||||
}
|
||||
|
||||
def method_matplotlib(image, cmap, bytes=False):
|
||||
if is_matplotlib_available():
|
||||
import matplotlib
|
||||
@@ -298,24 +318,19 @@ class MarigoldImageProcessor(ConfigMixin):
|
||||
else:
|
||||
image = image.float()
|
||||
|
||||
- if cmap != "Spectral":
-     raise ValueError("Only 'Spectral' color map is available without installing matplotlib.")
+ is_cmap_reversed = cmap.endswith("_r")
+ if is_cmap_reversed:
+     cmap = cmap[:-2]

- _Spectral_data = (  # Taken from matplotlib/_cm.py
-     (0.61960784313725492, 0.003921568627450980, 0.25882352941176473),  # 0.0 -> [0]
-     (0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
-     (0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
-     (0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
-     (0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
-     (1.0, 1.0, 0.74901960784313726),
-     (0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
-     (0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
-     (0.4, 0.76078431372549016, 0.6470588235294118),
-     (0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
-     (0.36862745098039218, 0.30980392156862746, 0.63529411764705879),  # 1.0 -> [K-1]
- )
+ if cmap not in supported_cmaps:
+     raise ValueError(
+         f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib."
+     )

- cmap = torch.tensor(_Spectral_data, dtype=torch.float, device=image.device)  # [K,3]
+ cmap = supported_cmaps[cmap]
+ if is_cmap_reversed:
+     cmap = cmap[::-1]
+ cmap = torch.tensor(cmap, dtype=torch.float, device=image.device)  # [K,3]
|
||||
K = cmap.shape[0]
|
||||
|
||||
pos = image.clamp(min=0, max=1) * (K - 1)
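The table lookup that this change introduces can be sketched independently of the library. The NumPy version below handles the `_r` suffix and the `pos = clamp(image) * (K - 1)` scaling as in the diff; the linear interpolation between neighboring table entries is an assumption of this sketch (the rest of the method is not shown here), and the names `CMAPS` and `colorize` are hypothetical.

```python
import numpy as np

# Color table for the "binary" map, copied from the diff above: white to black.
CMAPS = {
    "binary": [(1.0, 1.0, 1.0), (0.0, 0.0, 0.0)],
}


def colorize(image: np.ndarray, cmap: str = "binary") -> np.ndarray:
    """Map values in [0, 1] to RGB by interpolating within a K-entry color table."""
    is_reversed = cmap.endswith("_r")
    name = cmap[:-2] if is_reversed else cmap
    if name not in CMAPS:
        raise ValueError(f"Only {list(CMAPS.keys())} color maps are available.")

    table = np.asarray(CMAPS[name], dtype=np.float32)  # [K, 3]
    if is_reversed:
        table = table[::-1]
    K = table.shape[0]

    pos = np.clip(image, 0.0, 1.0) * (K - 1)  # fractional index into the table
    left = np.floor(pos).astype(np.int64)  # lower neighboring entry
    right = np.minimum(left + 1, K - 1)  # upper neighboring entry, clamped
    frac = (pos - left)[..., None]  # interpolation weight in [0, 1)
    return (1.0 - frac) * table[left] + frac * table[right]


gray = np.array([0.0, 0.5, 1.0], dtype=np.float32)
rgb = colorize(gray)  # white, mid-gray, black
rgb_reversed = colorize(gray, "binary_r")  # black, mid-gray, white
```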
|
||||
|
||||