mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Community pipeline] Marigold depth estimation update -- align with marigold v0.1.5 (#7524)
* add resample option; check denoise_step; update ckpt path * Add seeding in pipeline to increase reproducibility * fix typo * fix typo
This commit is contained in:
@@ -85,14 +85,25 @@ This depth estimation pipeline processes a single input image through multiple d
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# Original DDIM version (higher quality)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"Bingxin/Marigold",
|
||||
"prs-eth/marigold-v1-0",
|
||||
custom_pipeline="marigold_depth_estimation"
|
||||
# torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
|
||||
# variant="fp16", # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
|
||||
)
|
||||
|
||||
# (New) LCM version (faster speed)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"prs-eth/marigold-lcm-v1-0",
|
||||
custom_pipeline="marigold_depth_estimation"
|
||||
# torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
|
||||
# variant="fp16", # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
|
||||
)
|
||||
|
||||
pipe.to("cuda")
|
||||
@@ -101,12 +112,21 @@ img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_e
|
||||
image: Image.Image = load_image(img_path_or_url)
|
||||
|
||||
pipeline_output = pipe(
|
||||
image, # Input image.
|
||||
image, # Input image.
|
||||
# ----- recommended setting for DDIM version -----
|
||||
# denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10.
|
||||
# ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10.
|
||||
# ------------------------------------------------
|
||||
|
||||
# ----- recommended setting for LCM version ------
|
||||
# denoising_steps=4,
|
||||
# ensemble_size=5,
|
||||
# -------------------------------------------------
|
||||
|
||||
# processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
|
||||
# match_input_res=True, # (optional) Resize depth prediction to match input resolution.
|
||||
# batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
|
||||
# seed=2024, # (optional) Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing `batch_size=1` helps to increase reproducibility. To ensure full reproducibility, deterministic mode needs to be used.
|
||||
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation.
|
||||
# show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Union
|
||||
|
||||
@@ -25,6 +26,7 @@ import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.Image import Resampling
|
||||
from scipy.optimize import minimize
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from tqdm.auto import tqdm
|
||||
@@ -34,13 +36,14 @@ from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
LCMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
|
||||
check_min_version("0.28.0.dev0")
|
||||
check_min_version("0.25.0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
@@ -61,6 +64,19 @@ class MarigoldDepthOutput(BaseOutput):
|
||||
uncertainty: Union[None, np.ndarray]
|
||||
|
||||
|
||||
def get_pil_resample_method(method_str: str) -> Resampling:
    """
    Map a resampling method name to the corresponding `PIL.Image.Resampling` member.

    Args:
        method_str (`str`):
            Name of the resampling method. One of `"bilinear"`, `"bicubic"` or `"nearest"`.

    Returns:
        `PIL.Image.Resampling`: Resampling filter matching `method_str`.

    Raises:
        ValueError: If `method_str` is not one of the supported names.
    """
    resample_method_dic = {
        "bilinear": Resampling.BILINEAR,
        "bicubic": Resampling.BICUBIC,
        "nearest": Resampling.NEAREST,
    }
    resample_method = resample_method_dic.get(method_str)
    if resample_method is None:
        # Report the name the caller passed in: the lookup result is always
        # None on this path, so printing it would hide the offending value.
        raise ValueError(f"Unknown resampling method: {method_str}")
    return resample_method
|
||||
|
||||
|
||||
class MarigoldPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
|
||||
@@ -113,7 +129,9 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
ensemble_size: int = 10,
|
||||
processing_res: int = 768,
|
||||
match_input_res: bool = True,
|
||||
resample_method: str = "bilinear",
|
||||
batch_size: int = 0,
|
||||
seed: Union[int, None] = None,
|
||||
color_map: str = "Spectral",
|
||||
show_progress_bar: bool = True,
|
||||
ensemble_kwargs: Dict = None,
|
||||
@@ -129,7 +147,9 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
If set to 0: will not resize at all.
|
||||
match_input_res (`bool`, *optional*, defaults to `True`):
|
||||
Resize depth prediction to match input resolution.
|
||||
Only valid if `limit_input_res` is not None.
|
||||
Only valid if `processing_res` > 0.
|
||||
resample_method (`str`, *optional*, defaults to `bilinear`):
|
||||
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
|
||||
denoising_steps (`int`, *optional*, defaults to `10`):
|
||||
Number of diffusion denoising steps (DDIM) during inference.
|
||||
ensemble_size (`int`, *optional*, defaults to `10`):
|
||||
@@ -137,6 +157,8 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to `0`):
|
||||
Inference batch size, no bigger than `num_ensemble`.
|
||||
If set to 0, the script will automatically decide the proper batch size.
|
||||
seed (`int`, *optional*, defaults to `None`):
|
||||
Reproducibility seed.
|
||||
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
|
||||
@@ -146,8 +168,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
Returns:
|
||||
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
||||
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
||||
- **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
|
||||
values in [0, 1]. None if `color_map` is `None`
|
||||
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
|
||||
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
|
||||
coming from ensembling. None if `ensemble_size = 1`
|
||||
"""
|
||||
@@ -158,13 +179,21 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
if not match_input_res:
|
||||
assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
|
||||
assert processing_res >= 0
|
||||
assert denoising_steps >= 1
|
||||
assert ensemble_size >= 1
|
||||
|
||||
# Check if denoising step is reasonable
|
||||
self._check_inference_step(denoising_steps)
|
||||
|
||||
resample_method: Resampling = get_pil_resample_method(resample_method)
|
||||
|
||||
# ----------------- Image Preprocess -----------------
|
||||
# Resize image
|
||||
if processing_res > 0:
|
||||
input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res)
|
||||
input_image = self.resize_max_res(
|
||||
input_image,
|
||||
max_edge_resolution=processing_res,
|
||||
resample_method=resample_method,
|
||||
)
|
||||
# Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
|
||||
input_image = input_image.convert("RGB")
|
||||
image = np.asarray(input_image)
|
||||
@@ -203,9 +232,10 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
rgb_in=batched_img,
|
||||
num_inference_steps=denoising_steps,
|
||||
show_pbar=show_progress_bar,
|
||||
seed=seed,
|
||||
)
|
||||
depth_pred_ls.append(depth_pred_raw.detach().clone())
|
||||
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
|
||||
depth_pred_ls.append(depth_pred_raw.detach())
|
||||
depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
|
||||
torch.cuda.empty_cache() # clear vram cache for ensembling
|
||||
|
||||
# ----------------- Test-time ensembling -----------------
|
||||
@@ -227,7 +257,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
# Resize back to original resolution
|
||||
if match_input_res:
|
||||
pred_img = Image.fromarray(depth_pred)
|
||||
pred_img = pred_img.resize(input_size)
|
||||
pred_img = pred_img.resize(input_size, resample=resample_method)
|
||||
depth_pred = np.asarray(pred_img)
|
||||
|
||||
# Clip output range
|
||||
@@ -243,12 +273,32 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
else:
|
||||
depth_colored_img = None
|
||||
|
||||
return MarigoldDepthOutput(
|
||||
depth_np=depth_pred,
|
||||
depth_colored=depth_colored_img,
|
||||
uncertainty=pred_uncert,
|
||||
)
|
||||
|
||||
def _check_inference_step(self, n_step: int):
    """
    Sanity-check the requested number of denoising steps against the active scheduler.

    A warning is logged (inference is not blocked) when the step count falls outside the
    range recommended for the scheduler type: fewer than 10 steps for `DDIMScheduler`,
    or anything outside 1-4 steps for `LCMScheduler`.

    Args:
        n_step (`int`): Number of denoising steps requested for inference.

    Raises:
        AssertionError: If `n_step` is smaller than 1.
        RuntimeError: If `self.scheduler` is neither a `DDIMScheduler` nor an `LCMScheduler`.
    """
    assert n_step >= 1

    scheduler = self.scheduler

    # DDIM checkpoints: only a lower bound is recommended.
    if isinstance(scheduler, DDIMScheduler):
        if n_step < 10:
            logging.warning(
                f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
            )
        return

    # LCM checkpoints: warn for any step count outside 1-4.
    if isinstance(scheduler, LCMScheduler):
        if n_step < 1 or n_step > 4:
            logging.warning(f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps.")
        return

    raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
|
||||
|
||||
def _encode_empty_text(self):
|
||||
"""
|
||||
Encode text embedding for empty prompt.
|
||||
@@ -265,7 +315,13 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor:
|
||||
def single_infer(
|
||||
self,
|
||||
rgb_in: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
seed: Union[int, None],
|
||||
show_pbar: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform an individual depth prediction without ensembling.
|
||||
|
||||
@@ -286,10 +342,20 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
timesteps = self.scheduler.timesteps # [T]
|
||||
|
||||
# Encode image
|
||||
rgb_latent = self._encode_rgb(rgb_in)
|
||||
rgb_latent = self.encode_rgb(rgb_in)
|
||||
|
||||
# Initial depth map (noise)
|
||||
depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w]
|
||||
if seed is None:
|
||||
rand_num_generator = None
|
||||
else:
|
||||
rand_num_generator = torch.Generator(device=device)
|
||||
rand_num_generator.manual_seed(seed)
|
||||
depth_latent = torch.randn(
|
||||
rgb_latent.shape,
|
||||
device=device,
|
||||
dtype=self.dtype,
|
||||
generator=rand_num_generator,
|
||||
) # [B, 4, h, w]
|
||||
|
||||
# Batched empty text embedding
|
||||
if self.empty_text_embed is None:
|
||||
@@ -314,9 +380,9 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
|
||||
torch.cuda.empty_cache()
|
||||
depth = self._decode_depth(depth_latent)
|
||||
depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rand_num_generator).prev_sample
|
||||
|
||||
depth = self.decode_depth(depth_latent)
|
||||
|
||||
# clip prediction
|
||||
depth = torch.clip(depth, -1.0, 1.0)
|
||||
@@ -325,7 +391,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
|
||||
return depth
|
||||
|
||||
def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
||||
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode RGB image into latent.
|
||||
|
||||
@@ -344,7 +410,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
rgb_latent = mean * self.rgb_latent_scale_factor
|
||||
return rgb_latent
|
||||
|
||||
def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
|
||||
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Decode depth latent into depth map.
|
||||
|
||||
@@ -365,7 +431,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
return depth_mean
|
||||
|
||||
@staticmethod
|
||||
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
||||
def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
|
||||
"""
|
||||
Resize image to limit maximum edge length while keeping aspect ratio.
|
||||
|
||||
@@ -374,6 +440,8 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
Image to be resized.
|
||||
max_edge_resolution (`int`):
|
||||
Maximum edge length (pixel).
|
||||
resample_method (`PIL.Image.Resampling`):
|
||||
Resampling method used to resize images.
|
||||
|
||||
Returns:
|
||||
`Image.Image`: Resized image.
|
||||
@@ -384,7 +452,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
new_width = int(original_width * downscale_factor)
|
||||
new_height = int(original_height * downscale_factor)
|
||||
|
||||
resized_img = img.resize((new_width, new_height))
|
||||
resized_img = img.resize((new_width, new_height), resample=resample_method)
|
||||
return resized_img
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user