mirror of https://github.com/huggingface/diffusers.git

update
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import inspect
 import os
 from contextlib import nullcontext

 import gguf
@@ -29,7 +29,11 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module, remove_hook_from_module


-can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
+can_use_cuda_kernels = (
+    os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "true").lower() in ["1", "true", "yes"]
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 7
+)
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel

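The gate above is evaluated once at module import, and DIFFUSERS_GGUF_CUDA_KERNELS defaults to "true". A minimal sketch of opting out of the CUDA kernels, assuming only that the variable must be set before the gated module is imported:

import os

# Any value outside {"1", "true", "yes"} (case-insensitive) disables the kernels;
# leaving the variable unset keeps them enabled on capable GPUs.
os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "0"

# Import diffusers only after the variable is set, since the gate is read at import time.
import diffusers  # noqa: E402

With the kernels disabled, GGUF layers take the native dequantize path (forward_native in the test below).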
@@ -29,6 +29,7 @@ from diffusers.utils.testing_utils import (
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
+    require_accelerator,
     require_big_accelerator,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
@@ -37,11 +38,68 @@ from diffusers.utils.testing_utils import (

 if is_gguf_available():
     import gguf

     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter

 enable_full_determinism()


+@nightly
+@require_accelerate
+@require_accelerator
+@require_gguf_version_greater_or_equal("0.10.0")
+class GGUFCudaKernelsTests(unittest.TestCase):
+    def setUp(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_cuda_kernels_vs_native(self):
+        if torch_device != "cuda":
+            self.skipTest("CUDA kernels test requires CUDA device")
+
+        from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
+
+        if not can_use_cuda_kernels:
+            self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
+
+        test_quant_types = ["Q4_0", "Q4_K"]
+        test_shape = (1, 64, 512)  # batch, seq_len, hidden_dim
+        compute_dtype = torch.bfloat16
+
+        for quant_type in test_quant_types:
+            qtype = getattr(gguf.GGMLQuantizationType, quant_type)
+            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
+
+            in_features, out_features = 512, 512
+            total_elements = in_features * out_features
+            n_blocks = total_elements // block_size
+            weight_bytes = n_blocks * type_size
+
+            torch.manual_seed(42)
+            weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
+            weight = GGUFParameter(weight_data, quant_type=qtype)
+
+            x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
+
+            linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
+            linear.weight = weight
+            linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
+            linear = linear.to(torch_device)
+
+            with torch.no_grad():
+                output_native = linear.forward_native(x)
+                output_cuda = linear.forward_cuda(x)
+
+            # Compare outputs
+            max_diff = torch.abs(output_cuda - output_native).max()
+            assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
+
+
 @nightly
 @require_big_accelerator
 @require_accelerate
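For reference, the buffer-size arithmetic in test_cuda_kernels_vs_native mirrors GGUF's packing: gguf.GGML_QUANT_SIZES maps each quantization type to (block_size, type_size), i.e. the number of weights per block and the bytes each packed block occupies. A worked check for the Q4_0 case used above (32 weights per 18-byte block), assuming only that the gguf package is installed:

import gguf

# (block_size, type_size) per quantization type, straight from the gguf package.
block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0]  # (32, 18)

total_elements = 512 * 512               # 262144 weights in the test's 512x512 linear layer
n_blocks = total_elements // block_size  # 262144 // 32 = 8192 blocks
weight_bytes = n_blocks * type_size      # 8192 * 18 = 147456 raw uint8 bytes

assert weight_bytes == 147456

Q4_K happens to give the same total (256-weight blocks of 144 bytes each: 1024 blocks, also 147456 bytes), since both formats spend 4.5 bits per weight. Because the new class is @nightly-gated, the test is skipped in normal CI runs; something like RUN_NIGHTLY=1 pytest tests/quantization/gguf/test_gguf.py -k cuda_kernels_vs_native should select it, assuming the usual diffusers test layout and nightly flag name.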