1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/sdnq/packed_float.py
vladmandic 4e8b0f83b4 lint
Signed-off-by: vladmandic <mandic00@live.com>
2026-01-01 16:33:49 +01:00

103 lines
3.1 KiB
Python

import torch
from .common import dtype_dict
from .packed_int import pack_int_asymetric, unpack_int_asymetric
# Map a sub-byte float's total bit width to the name of the unsigned-int
# container format used by pack_int_asymetric / unpack_int_asymetric.
float_bits_to_uint_dict = {bits: f"uint{bits}" for bits in range(1, 8)}
def pack_float(x: torch.FloatTensor, weights_dtype: str) -> torch.Tensor:
    """Pack a float32 tensor into a custom low-bit float format.

    The target format is described by ``dtype_dict[weights_dtype]``
    (exponent/mantissa widths, total bits, signedness, smallest normal).
    Values are rounded, their exponent field narrowed from 8 bits to
    ``exponent_bits``, and the resulting codes are either bit-packed
    (formats under 8 bits) or stored in the format's storage dtype.

    NOTE(review): correctness of the bit twiddling below depends on the
    exact layout conventions defined in dtype_dict — verify against
    .common / .packed_int when modifying.
    """
    exponent_bits = dtype_dict[weights_dtype]["exponent"]
    mantissa_bits = dtype_dict[weights_dtype]["mantissa"]
    total_bits = dtype_dict[weights_dtype]["num_bits"]
    if dtype_dict[weights_dtype]["is_unsigned"]:
        # single top bit selected; presumably the top exponent bit for
        # sign-less formats — TODO confirm against dtype_dict layout
        sign_mask = (1 << (total_bits-1)) # pylint: disable=superfluous-parens
    else:
        # top two bits (sign bit plus the bit below it) — TODO confirm
        sign_mask = (1 << (total_bits-1)) + (1 << (total_bits-2))
    # How many mantissa / exponent bits of a float32 (23m / 8e) get dropped.
    mantissa_difference = 23 - mantissa_bits
    exponent_difference = 8 - exponent_bits
    # Value of one ULP of the target mantissa, expressed in float32 bit units.
    mantissa_mask = (1 << mantissa_difference) # pylint: disable=superfluous-parens
    # Reinterpret the float32 bits as int32 so we can manipulate them directly.
    x = x.to(dtype=torch.float32).view(torch.int32)
    # Rounding step: inspect the high bits of the mantissa portion that will
    # be truncated; when they exceed half a target ULP, bump the value up by
    # one target ULP (carry propagates naturally into the exponent field).
    # NOTE(review): the (mantissa_difference-4) window looks deliberate but
    # its exact rounding mode (vs. plain round-half-up) is not obvious — confirm.
    x = torch.where(
        torch.gt(
            torch.bitwise_and(x, -(1 << (mantissa_difference-4)) & ~(-mantissa_mask)),
            (1 << (mantissa_difference-1)),
        ),
        torch.add(x, mantissa_mask),
        x,
    )
    # Flush magnitudes below the format's smallest normal value to zero.
    x = torch.where(torch.lt(x.view(torch.float32).abs(), dtype_dict[weights_dtype]["min_normal"]), 0, x)
    # Drop the truncated mantissa bits.
    x = torch.bitwise_right_shift(x, mantissa_difference)
    # Narrow the exponent: shift the sign/top bits down over the removed
    # exponent bits, keep the low (mantissa + low-exponent) bits in place,
    # then mask the result to exactly total_bits.
    x = torch.bitwise_and(
        torch.bitwise_or(
            torch.bitwise_and(torch.bitwise_right_shift(x, exponent_difference), sign_mask),
            torch.bitwise_and(x, ~sign_mask),
        ),
        ~(-(1 << total_bits)),
    ).view(torch.uint32)
    if total_bits < 8:
        # Sub-byte codes are bit-packed into a dense uint container.
        x = pack_int_asymetric(x, float_bits_to_uint_dict[total_bits])
    else:
        x = x.to(dtype=dtype_dict[weights_dtype]["storage_dtype"])
    return x
def unpack_float(x: torch.Tensor, shape: torch.Size, weights_dtype: str) -> torch.FloatTensor:
    """Unpack a tensor produced by ``pack_float`` back into float32.

    ``shape`` is the original (pre-packing) tensor shape, needed to undo
    the sub-byte bit packing. The target format is described by
    ``dtype_dict[weights_dtype]`` as in ``pack_float``.

    NOTE(review): inverse of pack_float's bit layout; the magic constants
    below (0x40000000, 0x3F800000) reconstruct the float32 exponent's high
    bits — verify jointly with pack_float when modifying either.
    """
    exponent_bits = dtype_dict[weights_dtype]["exponent"]
    mantissa_bits = dtype_dict[weights_dtype]["mantissa"]
    total_bits = dtype_dict[weights_dtype]["num_bits"]
    if dtype_dict[weights_dtype]["is_unsigned"]:
        # Mirrors pack_float's sign_mask choice for sign-less formats.
        sign_mask = (1 << (total_bits-1)) # pylint: disable=superfluous-parens
    else:
        sign_mask = (1 << (total_bits-1)) + (1 << (total_bits-2))
    # Bit-width deltas between float32 (23m / 8e) and the packed format.
    mantissa_difference = 23 - mantissa_bits
    exponent_difference = 8 - exponent_bits
    if total_bits < 8:
        # Undo the dense sub-byte packing first.
        x = unpack_int_asymetric(x, shape, float_bits_to_uint_dict[total_bits])
    # Reinterpret the codes as int32 for bit manipulation.
    x = x.to(dtype=torch.uint32).view(torch.int32)
    # Re-widen the layout: move the sign/top bits back up past the restored
    # exponent width, keep the low bits, then shift the mantissa back into
    # float32 position (inverse of pack_float's narrowing step).
    x = torch.bitwise_left_shift(
        torch.bitwise_or(
            torch.bitwise_left_shift(torch.bitwise_and(x, sign_mask), exponent_difference),
            torch.bitwise_and(x, ~sign_mask),
        ),
        mantissa_difference,
    )
    # Reconstruct the missing high exponent bits: 1073741824 is 0x40000000
    # (float32 exponent MSB); negating its complement and shifting produces
    # a run of ones that is masked to 1065353216 (0x3F800000, the exponent
    # bias pattern) and OR-ed in. NOTE(review): sign-extension via the unary
    # minus is load-bearing here — confirm before touching.
    x = torch.bitwise_or(
        x,
        torch.bitwise_and(
            torch.bitwise_right_shift(
                -torch.bitwise_and(torch.bitwise_not(x), 1073741824),
                exponent_difference,
            ),
            1065353216,
        ),
    )
    # Zero-detection: codes whose exponent/sign region is entirely clear
    # decode to exactly 0.0 rather than a tiny denormal-like value.
    overflow_mask = (~(-(1 << (22 + exponent_bits))) | -1073741824)
    x = torch.where(torch.bitwise_and(x, overflow_mask).to(dtype=torch.bool), x, 0)
    # Reinterpret the assembled bits as float32.
    x = x.view(torch.float32)
    return x