mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix overflow and dtype handling in rgblike_to_depthmap (NumPy + PyTorch) (#12546)
* Fix overflow in rgblike_to_depthmap by safe dtype casting (torch & NumPy) * Fix: store original dtype and cast back after safe computation * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b3e9dfced7
commit
e4393fa613
@@ -1045,16 +1045,39 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
r"""
|
||||
Convert an RGB-like depth image to a depth map.
|
||||
|
||||
Args:
|
||||
image (`Union[np.ndarray, torch.Tensor]`):
|
||||
The RGB-like depth image to convert.
|
||||
|
||||
Returns:
|
||||
`Union[np.ndarray, torch.Tensor]`:
|
||||
The corresponding depth map.
|
||||
"""
|
||||
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
||||
# 1. Cast the tensor to a larger integer type (e.g., int32)
|
||||
# to safely perform the multiplication by 256.
|
||||
# 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
|
||||
# 3. Cast the final result to the desired depth map type (uint16) if needed
|
||||
# before returning, though leaving it as int32/int64 is often safer
|
||||
# for return value from a library function.
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Cast to a safe dtype (e.g., int32 or int64) for the calculation
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.to(torch.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# You may want to cast the final result to uint16, but casting to a
|
||||
# larger int type (like int32) is sufficient to fix the overflow.
|
||||
# depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.to(original_dtype)
|
||||
|
||||
elif isinstance(image, np.ndarray):
|
||||
# NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.astype(np.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.astype(original_dtype)
|
||||
else:
|
||||
raise TypeError("Input image must be a torch.Tensor or np.ndarray")
|
||||
|
||||
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
r"""
|
||||
|
||||
Reference in New Issue
Block a user