From de1fb4b615b9941e77602d132a36795a6f2d2961 Mon Sep 17 00:00:00 2001
From: DN6
Date: Thu, 24 Jul 2025 08:31:47 +0530
Subject: [PATCH] update

---
 src/diffusers/quantizers/gguf/utils.py |  8 +++-
 tests/quantization/gguf/test_gguf.py   | 58 ++++++++++++++++++++++++++
 2 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 31f6ec3e73..edbc60abf5 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -12,8 +12,8 @@
 # # See the License for the specific language governing permissions and
 # # limitations under the License.
-
 import inspect
+import os
 from contextlib import nullcontext
 
 import gguf
 import torch
@@ -29,7 +29,11 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module, remove_hook_from_module
 
 
-can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
+can_use_cuda_kernels = (
+    os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "true").lower() in ["1", "true", "yes"]
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 7
+)
 
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 0d786de7e7..aa558b3e82 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -29,6 +29,7 @@ from diffusers.utils.testing_utils import (
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
+    require_accelerator,
     require_big_accelerator,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
@@ -37,11 +38,68 @@ from diffusers.utils.testing_utils import (
 
 if is_gguf_available():
+    import gguf
+
     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
 
 
 enable_full_determinism()
 
 
+@nightly
+@require_accelerate
+@require_accelerator
+@require_gguf_version_greater_or_equal("0.10.0")
+class GGUFCudaKernelsTests(unittest.TestCase):
+    def setUp(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_cuda_kernels_vs_native(self):
+        if torch_device != "cuda":
+            self.skipTest("CUDA kernels test requires CUDA device")
+
+        from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
+
+        if not can_use_cuda_kernels:
+            self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
+
+        test_quant_types = ["Q4_0", "Q4_K"]
+        test_shape = (1, 64, 512)  # batch, seq_len, hidden_dim
+        compute_dtype = torch.bfloat16
+
+        for quant_type in test_quant_types:
+            qtype = getattr(gguf.GGMLQuantizationType, quant_type)
+            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
+
+            in_features, out_features = 512, 512
+            total_elements = in_features * out_features
+            n_blocks = total_elements // block_size
+            weight_bytes = n_blocks * type_size
+
+            torch.manual_seed(42)
+            weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
+            weight = GGUFParameter(weight_data, quant_type=qtype)
+
+            x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
+
+            linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
+            linear.weight = weight
+            linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
+            linear = linear.to(torch_device)
+
+            with torch.no_grad():
+                output_native = linear.forward_native(x)
+                output_cuda = linear.forward_cuda(x)
+
+            # Compare outputs
+            max_diff = torch.abs(output_cuda - output_native).max()
+            assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
+
+
 @nightly
 @require_big_accelerator
 @require_accelerate
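
Usage note (not part of the patch): the new gate is evaluated once, at module import, so the environment variable must be set before diffusers.quantizers.gguf.utils is first imported. Defaulting to "true" keeps the fused kernels on for existing users; the variable only needs to be set to fall back to the native dequantize-then-matmul path, e.g. when debugging numerical differences like the ones the new test checks. A minimal sketch of opting out, assuming only the names visible in the diff above:

    import os

    # Any value outside {"1", "true", "yes"} (case-insensitive) disables the
    # fused CUDA kernel path introduced by this patch.
    os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "0"

    # The flag is read at import time, so set the variable before this import.
    from diffusers.quantizers.gguf.utils import can_use_cuda_kernels

    # With the variable set to "0", the and-chain short-circuits to False
    # regardless of CUDA availability or compute capability.
    assert can_use_cuda_kernels is False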