diff --git a/examples/community/README.md b/examples/community/README.md
index 87c764ac7b..188eee41c0 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -85,14 +85,25 @@ This depth estimation pipeline processes a single input image through multiple d
 ```python
 import numpy as np
+import torch
 from PIL import Image
 
 from diffusers import DiffusionPipeline
 from diffusers.utils import load_image
 
+# Original DDIM version (higher quality)
 pipe = DiffusionPipeline.from_pretrained(
-    "Bingxin/Marigold",
+    "prs-eth/marigold-v1-0",
     custom_pipeline="marigold_depth_estimation"
     # torch_dtype=torch.float16,  # (optional) Run with half-precision (16-bit float).
+    # variant="fp16",             # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
+)
+
+# (New) LCM version (faster speed)
+pipe = DiffusionPipeline.from_pretrained(
+    "prs-eth/marigold-lcm-v1-0",
+    custom_pipeline="marigold_depth_estimation"
+    # torch_dtype=torch.float16,  # (optional) Run with half-precision (16-bit float).
+    # variant="fp16",             # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
 )
 
 pipe.to("cuda")
@@ -101,12 +112,21 @@ img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_e
 image: Image.Image = load_image(img_path_or_url)
 
 pipeline_output = pipe(
-    image,                  # Input image.
+    image,                    # Input image.
+    # ----- recommended setting for DDIM version -----
     # denoising_steps=10,     # (optional) Number of denoising steps of each inference pass. Default: 10.
     # ensemble_size=10,       # (optional) Number of inference passes in the ensemble. Default: 10.
+    # ------------------------------------------------
+
+    # ----- recommended setting for LCM version ------
+    # denoising_steps=4,
+    # ensemble_size=5,
+    # -------------------------------------------------
+
     # processing_res=768,     # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
     # match_input_res=True,   # (optional) Resize depth prediction to match input resolution.
     # batch_size=0,           # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
+    # seed=2024,              # (optional) Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing `batch_size=1` helps to increase reproducibility. To ensure full reproducibility, deterministic mode needs to be used.
     # color_map="Spectral",   # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation.
     # show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
 )
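For reference, a minimal sketch of how the `MarigoldDepthOutput` returned by the call above might be consumed. The field names (`depth_np`, `depth_colored`) come from the pipeline's output class in this diff; saving the raw prediction as a 16-bit PNG is an assumed convention, not something the diff prescribes:

```python
import numpy as np
from PIL import Image

# Fields of the MarigoldDepthOutput returned by the pipeline call above
depth: np.ndarray = pipeline_output.depth_np                 # depth values in [0, 1]
depth_colored: Image.Image = pipeline_output.depth_colored   # colorized visualization

# Save the raw prediction as a 16-bit PNG (assumed convention for depth maps)
depth_uint16 = (depth * 65535.0).astype(np.uint16)
Image.fromarray(depth_uint16).save("./depth_map.png")

# Save the colorized map
depth_colored.save("./depth_colored.png")
```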
diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py
index 72a17dfc70..ef1b45b942 100644
--- a/examples/community/marigold_depth_estimation.py
+++ b/examples/community/marigold_depth_estimation.py
@@ -18,6 +18,7 @@
 # --------------------------------------------------------------------------
 
 
+import logging
 import math
 from typing import Dict, Union
 
@@ -25,6 +26,7 @@ import matplotlib
 import numpy as np
 import torch
 from PIL import Image
+from PIL.Image import Resampling
 from scipy.optimize import minimize
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm.auto import tqdm
@@ -34,13 +36,14 @@ from diffusers import (
     AutoencoderKL,
     DDIMScheduler,
     DiffusionPipeline,
+    LCMScheduler,
     UNet2DConditionModel,
 )
 from diffusers.utils import BaseOutput, check_min_version
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.28.0.dev0")
+check_min_version("0.25.0")
 
 
 class MarigoldDepthOutput(BaseOutput):
@@ -61,6 +64,19 @@ class MarigoldDepthOutput(BaseOutput):
     uncertainty: Union[None, np.ndarray]
 
 
+def get_pil_resample_method(method_str: str) -> Resampling:
+    resample_method_dic = {
+        "bilinear": Resampling.BILINEAR,
+        "bicubic": Resampling.BICUBIC,
+        "nearest": Resampling.NEAREST,
+    }
+    resample_method = resample_method_dic.get(method_str, None)
+    if resample_method is None:
+        raise ValueError(f"Unknown resampling method: {method_str}")
+    else:
+        return resample_method
+
+
 class MarigoldPipeline(DiffusionPipeline):
     """
     Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
@@ -113,7 +129,9 @@ class MarigoldPipeline(DiffusionPipeline):
         ensemble_size: int = 10,
         processing_res: int = 768,
         match_input_res: bool = True,
+        resample_method: str = "bilinear",
         batch_size: int = 0,
+        seed: Union[int, None] = None,
         color_map: str = "Spectral",
         show_progress_bar: bool = True,
         ensemble_kwargs: Dict = None,
@@ -129,7 +147,9 @@ class MarigoldPipeline(DiffusionPipeline):
                 If set to 0: will not resize at all.
             match_input_res (`bool`, *optional*, defaults to `True`):
                 Resize depth prediction to match input resolution.
-                Only valid if `limit_input_res` is not None.
+                Only valid if `processing_res` > 0.
+            resample_method (`str`, *optional*, defaults to `bilinear`):
+                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
             denoising_steps (`int`, *optional*, defaults to `10`):
                 Number of diffusion denoising steps (DDIM) during inference.
             ensemble_size (`int`, *optional*, defaults to `10`):
@@ -137,6 +157,8 @@ class MarigoldPipeline(DiffusionPipeline):
             batch_size (`int`, *optional*, defaults to `0`):
                 Inference batch size, no bigger than `num_ensemble`.
                 If set to 0, the script will automatically decide the proper batch size.
+            seed (`int`, *optional*, defaults to `None`):
+                Reproducibility seed.
             show_progress_bar (`bool`, *optional*, defaults to `True`):
                 Display a progress bar of diffusion denoising.
             color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
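A quick sketch of what the new `seed` argument buys, assuming `pipe` and `image` from the README snippet above. As the `seed` comment notes, exact equality additionally depends on `batch_size=1` and deterministic backend settings, so this is an illustration rather than a guaranteed invariant:

```python
# Two identically seeded runs should yield matching depth maps
out_a = pipe(image, ensemble_size=2, batch_size=1, seed=2024)
out_b = pipe(image, ensemble_size=2, batch_size=1, seed=2024)
assert np.allclose(out_a.depth_np, out_b.depth_np)
```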
@@ -146,8 +168,7 @@ class MarigoldPipeline(DiffusionPipeline):
         Returns:
             `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
             - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
-            - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
-              values in [0, 1]. None if `color_map` is `None`
+            - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
             - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
               coming from ensembling. None if `ensemble_size = 1`
         """
@@ -158,13 +179,21 @@ class MarigoldPipeline(DiffusionPipeline):
         if not match_input_res:
             assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
         assert processing_res >= 0
-        assert denoising_steps >= 1
         assert ensemble_size >= 1
 
+        # Check if the denoising step count is reasonable
+        self._check_inference_step(denoising_steps)
+
+        resample_method: Resampling = get_pil_resample_method(resample_method)
+
         # ----------------- Image Preprocess -----------------
         # Resize image
         if processing_res > 0:
-            input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res)
+            input_image = self.resize_max_res(
+                input_image,
+                max_edge_resolution=processing_res,
+                resample_method=resample_method,
+            )
         # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
         input_image = input_image.convert("RGB")
         image = np.asarray(input_image)
@@ -203,9 +232,10 @@ class MarigoldPipeline(DiffusionPipeline):
                 rgb_in=batched_img,
                 num_inference_steps=denoising_steps,
                 show_pbar=show_progress_bar,
+                seed=seed,
             )
-            depth_pred_ls.append(depth_pred_raw.detach().clone())
-        depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
+            depth_pred_ls.append(depth_pred_raw.detach())
+        depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
         torch.cuda.empty_cache()  # clear vram cache for ensembling
 
         # ----------------- Test-time ensembling -----------------
@@ -227,7 +257,7 @@ class MarigoldPipeline(DiffusionPipeline):
         # Resize back to original resolution
         if match_input_res:
             pred_img = Image.fromarray(depth_pred)
-            pred_img = pred_img.resize(input_size)
+            pred_img = pred_img.resize(input_size, resample=resample_method)
             depth_pred = np.asarray(pred_img)
 
         # Clip output range
@@ -243,12 +273,32 @@ class MarigoldPipeline(DiffusionPipeline):
             depth_colored_img = Image.fromarray(depth_colored_hwc)
         else:
             depth_colored_img = None
+
         return MarigoldDepthOutput(
             depth_np=depth_pred,
             depth_colored=depth_colored_img,
             uncertainty=pred_uncert,
         )
 
+    def _check_inference_step(self, n_step: int):
+        """
+        Check if the denoising step count is reasonable.
+        Args:
+            n_step (`int`): denoising steps
+        """
+        assert n_step >= 1
+
+        if isinstance(self.scheduler, DDIMScheduler):
+            if n_step < 10:
+                logging.warning(
+                    f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
+                )
+        elif isinstance(self.scheduler, LCMScheduler):
+            if not 1 <= n_step <= 4:
+                logging.warning(f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps.")
+        else:
+            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
+
     def _encode_empty_text(self):
         """
         Encode text embedding for empty prompt.
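The `_check_inference_step` guard above keys its advice to the scheduler class bundled with the loaded checkpoint. A sketch of choosing step counts accordingly, assuming `pipe` and `image` from the README snippet; the recommended values mirror the README comments in this diff:

```python
from diffusers import DDIMScheduler, LCMScheduler

# Match denoising_steps to the scheduler the loaded checkpoint ships with
if isinstance(pipe.scheduler, LCMScheduler):
    # LCM checkpoint: 1-4 steps, smaller ensemble
    pipeline_output = pipe(image, denoising_steps=4, ensemble_size=5)
elif isinstance(pipe.scheduler, DDIMScheduler):
    # DDIM checkpoint: 10+ steps
    pipeline_output = pipe(image, denoising_steps=10, ensemble_size=10)
```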
@@ -265,7 +315,13 @@ class MarigoldPipeline(DiffusionPipeline):
         self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
 
     @torch.no_grad()
-    def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor:
+    def single_infer(
+        self,
+        rgb_in: torch.Tensor,
+        num_inference_steps: int,
+        seed: Union[int, None],
+        show_pbar: bool,
+    ) -> torch.Tensor:
         """
         Perform an individual depth prediction without ensembling.
 
@@ -286,10 +342,20 @@ class MarigoldPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps  # [T]
 
         # Encode image
-        rgb_latent = self._encode_rgb(rgb_in)
+        rgb_latent = self.encode_rgb(rgb_in)
 
         # Initial depth map (noise)
-        depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype)  # [B, 4, h, w]
+        if seed is None:
+            rand_num_generator = None
+        else:
+            rand_num_generator = torch.Generator(device=device)
+            rand_num_generator.manual_seed(seed)
+        depth_latent = torch.randn(
+            rgb_latent.shape,
+            device=device,
+            dtype=self.dtype,
+            generator=rand_num_generator,
+        )  # [B, 4, h, w]
 
         # Batched empty text embedding
         if self.empty_text_embed is None:
@@ -314,9 +380,9 @@ class MarigoldPipeline(DiffusionPipeline):
             noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample  # [B, 4, h, w]
 
             # compute the previous noisy sample x_t -> x_t-1
-            depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
-        torch.cuda.empty_cache()
-        depth = self._decode_depth(depth_latent)
+            depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rand_num_generator).prev_sample
+
+        depth = self.decode_depth(depth_latent)
 
         # clip prediction
         depth = torch.clip(depth, -1.0, 1.0)
@@ -325,7 +391,7 @@ class MarigoldPipeline(DiffusionPipeline):
 
         return depth
 
-    def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
+    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
         """
         Encode RGB image into latent.
 
@@ -344,7 +410,7 @@ class MarigoldPipeline(DiffusionPipeline):
         rgb_latent = mean * self.rgb_latent_scale_factor
         return rgb_latent
 
-    def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
+    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
         """
         Decode depth latent into depth map.
 
@@ -365,7 +431,7 @@ class MarigoldPipeline(DiffusionPipeline):
         return depth_mean
 
     @staticmethod
-    def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
+    def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
         """
         Resize image to limit maximum edge length while keeping aspect ratio.
 
@@ -374,6 +440,8 @@ class MarigoldPipeline(DiffusionPipeline):
                 Image to be resized.
             max_edge_resolution (`int`):
                 Maximum edge length (pixel).
+            resample_method (`PIL.Image.Resampling`):
+                Resampling method used to resize images.
 
         Returns:
             `Image.Image`: Resized image.
@@ -384,7 +452,7 @@ class MarigoldPipeline(DiffusionPipeline):
         new_width = int(original_width * downscale_factor)
         new_height = int(original_height * downscale_factor)
 
-        resized_img = img.resize((new_width, new_height))
+        resized_img = img.resize((new_width, new_height), resample=resample_method)
         return resized_img
 
     @staticmethod
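A small standalone check of the `resize_max_res` geometry added above, using hypothetical dimensions. Since it is a staticmethod, it can be exercised without instantiating the pipeline:

```python
from PIL import Image
from PIL.Image import Resampling

# 2000x1000 capped at a 768-pixel max edge: downscale_factor = 768/2000 = 0.384,
# so the result is 768x384 and the aspect ratio is preserved.
img = Image.new("RGB", (2000, 1000))
resized = MarigoldPipeline.resize_max_res(
    img, max_edge_resolution=768, resample_method=Resampling.BILINEAR
)
assert resized.size == (768, 384)
```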