mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
DN6
2025-07-24 08:31:47 +05:30
parent e46571a7aa
commit de1fb4b615
2 changed files with 64 additions and 2 deletions

View File

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import os
 from contextlib import nullcontext

 import gguf
@@ -29,7 +29,11 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module, remove_hook_from_module

-can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
+can_use_cuda_kernels = (
+    os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "true").lower() in ["1", "true", "yes"]
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 7
+)

 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel
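
Note: the fused kernels stay enabled by default; the new DIFFUSERS_GGUF_CUDA_KERNELS variable is read once at module import time, so it must be set before diffusers is imported. A minimal sketch of opting out (the module path is taken from the test import further down; this is an illustration, not part of the commit):

    import os

    # Any value outside {"1", "true", "yes"} disables the fused CUDA path.
    os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "false"

    from diffusers.quantizers.gguf.utils import can_use_cuda_kernels

    assert can_use_cuda_kernels is False  # falls back to the native dequantize path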

View File

@@ -29,6 +29,7 @@ from diffusers.utils.testing_utils import (
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
+    require_accelerator,
     require_big_accelerator,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
@@ -37,11 +38,68 @@ from diffusers.utils.testing_utils import (
 if is_gguf_available():
     import gguf

     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter


 enable_full_determinism()


+@nightly
+@require_accelerate
+@require_accelerator
+@require_gguf_version_greater_or_equal("0.10.0")
+class GGUFCudaKernelsTests(unittest.TestCase):
+    def setUp(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_cuda_kernels_vs_native(self):
+        if torch_device != "cuda":
+            self.skipTest("CUDA kernels test requires CUDA device")
+
+        from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
+
+        if not can_use_cuda_kernels:
+            self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
+
+        test_quant_types = ["Q4_0", "Q4_K"]
+        test_shape = (1, 64, 512)  # batch, seq_len, hidden_dim
+        compute_dtype = torch.bfloat16
+
+        for quant_type in test_quant_types:
+            qtype = getattr(gguf.GGMLQuantizationType, quant_type)
+            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
+
+            in_features, out_features = 512, 512
+            total_elements = in_features * out_features
+            n_blocks = total_elements // block_size
+            weight_bytes = n_blocks * type_size
+
+            torch.manual_seed(42)
+            weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
+            weight = GGUFParameter(weight_data, quant_type=qtype)
+
+            x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
+
+            linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
+            linear.weight = weight
+            linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
+            linear = linear.to(torch_device)
+
+            with torch.no_grad():
+                output_native = linear.forward_native(x)
+                output_cuda = linear.forward_cuda(x)
+
+            # Compare outputs
+            max_diff = torch.abs(output_cuda - output_native).max()
+            assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
+
+
 @nightly
 @require_big_accelerator
 @require_accelerate
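
Note on the weight sizing in the new test: the byte count follows directly from the GGML quantization tables. For Q4_0 (block_size 32, type_size 18), a 512x512 layer packs 262144 elements into 8192 blocks of 18 bytes = 147456 bytes; for Q4_K (block_size 256, type_size 144), the same layer packs into 1024 blocks of 144 bytes, also 147456 bytes. A sketch of exercising the test locally, assuming this diff belongs to tests/quantization/gguf/test_gguf.py (the path is not shown in this view) and that diffusers' RUN_NIGHTLY flag gates @nightly tests: RUN_NIGHTLY=1 pytest tests/quantization/gguf/test_gguf.py -k GGUFCudaKernelsTests -v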