import cv2
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

from modules import devices, masking
from modules.shared import opts


class DepthProDetector:
    """Apple DepthPro detector (aligned with Depth Anything style)."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path: str = "apple/DepthPro-hf", cache_dir: str = None, local_files_only: bool = False) -> "DepthProDetector":
        from transformers import AutoImageProcessor, DepthProForDepthEstimation

        processor = AutoImageProcessor.from_pretrained(pretrained_model_or_path, cache_dir=cache_dir, local_files_only=local_files_only)
        model = DepthProForDepthEstimation.from_pretrained(
            pretrained_model_or_path,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        ).to(devices.device).eval()
        return cls(model, processor)

    def __call__(self, image, color_map: str = "none", output_type: str = "pil"):
        self.model.to(devices.device)
        if isinstance(image, Image.Image):
            image = np.array(image)
        h, w = image.shape[:2]
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # the input array is treated as BGR; swap channels for the HF processor
        pil_image = Image.fromarray(image_rgb)

        inputs = self.processor(images=pil_image, return_tensors="pt")
        inputs = {k: v.to(devices.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

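        # run DepthPro; the processor's post-processing step rescales the predicted depth to the original (h, w)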
        with devices.inference_context():
            outputs = self.model(**inputs)
        results = self.processor.post_process_depth_estimation(outputs, target_sizes=[(h, w)])
        depth_tensor = results[0]["predicted_depth"].to(devices.device, dtype=torch.float32)

        if opts.control_move_processor:
            self.model.to("cpu")

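        # DepthPro predicts metric depth; invert it so near objects are bright, then min-max normalize to [0, 1]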
        depth_tensor = F.interpolate(depth_tensor[None, None], size=(h, w), mode="bilinear", align_corners=False)[0, 0]
        depth_tensor = 1.0 / torch.clamp(depth_tensor, min=1e-6)
        depth_tensor -= depth_tensor.min()
        depth_max = depth_tensor.max()
        if depth_max > 0:
            depth_tensor /= depth_max
        depth = (depth_tensor * 255.0).clamp(0, 255).to(torch.uint8).cpu().numpy()

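        # optional colorization: look up the OpenCV colormap by its index in masking.COLORMAP, then flip BGR to RGB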
if color_map != "none":
|
|
colormap_key = color_map if color_map in masking.COLORMAP else "inferno"
|
|
depth = cv2.applyColorMap(depth, masking.COLORMAP.index(colormap_key))[:, :, ::-1]
|
|
if output_type == "pil":
|
|
mode = "RGB" if depth.ndim == 3 else "L"
|
|
depth = Image.fromarray(depth, mode=mode)
|
|
return depth
|
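

# Usage sketch (an assumption, not part of this module): inside an sdnext session where
# modules.devices and modules.shared are initialized, the detector is used like the other
# depth detectors; "input.png" and "depth.png" are hypothetical local files.
#
#   detector = DepthProDetector.from_pretrained("apple/DepthPro-hf")
#   depth_map = detector(Image.open("input.png"), color_map="inferno", output_type="pil")
#   depth_map.save("depth.png")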