# diffusers/tests/quantization/modelopt/test_modelopt.py
# From: [Quantization] Add TRT-ModelOpt as a Backend (#11173)
# Author: Ishan Modi; Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
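"""Tests for the NVIDIA TensorRT Model Optimizer (ModelOpt) quantization backend.

Minimal usage sketch, mirroring what the tests below exercise (the checkpoint
and kwargs come straight from this file; `quant_type` may be any of the
formats covered by the concrete test classes at the bottom):

    config = NVIDIAModelOptConfig(quant_type="FP8")
    transformer = SD3Transformer2DModel.from_pretrained(
        "hf-internal-testing/tiny-sd3-pipe",
        subfolder="transformer",
        quantization_config=config,
        torch_dtype=torch.bfloat16,
    )
"""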

import gc
import tempfile
import unittest

from diffusers import NVIDIAModelOptConfig, SD3Transformer2DModel, StableDiffusion3Pipeline
from diffusers.utils import is_nvidia_modelopt_available, is_torch_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    enable_full_determinism,
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_accelerator,
    require_modelopt_version_greater_or_equal,
    require_torch_cuda_compatibility,
    torch_device,
)

if is_nvidia_modelopt_available():
    import modelopt.torch.quantization as mtq

if is_torch_available():
    import torch

    from ..utils import LoRALayer, get_memory_consumption_stat


enable_full_determinism()
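

# Shared test battery for every ModelOpt quantization recipe. Concrete subclasses
# at the bottom of the file pin a recipe via `get_dummy_init_kwargs` and adjust
# the expected memory reduction.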
@nightly
@require_big_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptBaseTesterMixin:
    model_id = "hf-internal-testing/tiny-sd3-pipe"
    model_cls = SD3Transformer2DModel
    pipeline_cls = StableDiffusion3Pipeline
    torch_dtype = torch.bfloat16
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False

    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def get_dummy_init_kwargs(self):
        return {"quant_type": "FP8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()),
            "subfolder": "transformer",
        }
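
    # Loading with a quantization config should leave every `torch.nn.Linear`
    # wrapped by a ModelOpt quantized module.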
    def test_modelopt_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert mtq.utils.is_quantized(module)

    def test_modelopt_memory_usage(self):
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
        }

        unquantized_model = self.model_cls.from_pretrained(
            self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer"
        )
        unquantized_model.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        quantized_model.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

    def test_keep_modules_in_fp32(self):
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to(torch_device)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32

        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules

    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()

        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = NVIDIAModelOptConfig(**quantization_config_kwargs)
        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to(torch_device)

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not mtq.utils.is_quantized(module)
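
    # ModelOpt-quantized models reject dtype casts; only device moves are allowed.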
    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device` and `dtype`
            device_0 = f"{torch_device}:0"
            model.to(device=device_0, dtype=torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.float()

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.half()

        # This should work
        model.to(torch_device)

    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)

    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
        assert max_diff < 1e-3

    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(),
                device_map={0: "8GB", "cpu": "16GB"},
            )

    def get_dummy_inputs(self):
        batch_size = 1
        seq_len = 16
        height = width = 32
        num_latent_channels = 4
        caption_channels = 8

        torch.manual_seed(0)
        hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(
            torch_device, dtype=torch.bfloat16
        )
        encoder_hidden_states = torch.randn((batch_size, seq_len, caption_channels)).to(
            torch_device, dtype=torch.bfloat16
        )
        timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
        }
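
    # End-to-end smoke test: the quantized transformer must run inside the full
    # pipeline with model CPU offload enabled.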
    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            self.model_id,
            quantization_config=NVIDIAModelOptConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(self.model_id, transformer=transformer, torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

    def test_training(self):
        quantization_config = NVIDIAModelOptConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            self.model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            # Freeze the base model; only the LoRA adapters below receive gradients.
            param.requires_grad = False
            if param.ndim == 1:
                # Cast small parameters (e.g. norms) to fp32 for stability.
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if hasattr(module, "to_q"):
                module.to_q = LoRALayer(module.to_q, rank=4)
            if hasattr(module, "to_k"):
                module.to_k = LoRALayer(module.to_k, rank=4)
            if hasattr(module, "to_v"):
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_inputs()
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
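

# Concrete recipes. Each class below pins a quantization recipe via
# `get_dummy_init_kwargs` (and adjusts the expected memory reduction);
# all checks come from `ModelOptBaseTesterMixin`.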
class SD3TransformerFP8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.6

    def get_dummy_init_kwargs(self):
        return {"quant_type": "FP8"}


class SD3TransformerINT8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.6
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"quant_type": "INT8"}


@require_torch_cuda_compatibility(8.0)
class SD3TransformerINT4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "INT4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "disable_conv_quantization": True,
        }


@require_torch_cuda_compatibility(8.0)
class SD3TransformerNF4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        # 128-element weight blocks; the `scale_*` kwargs quantize the scales
        # themselves (double quantization).
        return {
            "quant_type": "NF4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "scale_block_quantize": 8,
            "scale_channel_quantize": -1,
            "modules_to_not_convert": ["conv"],
        }


@require_torch_cuda_compatibility(8.0)
class SD3TransformerNVFP4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "NVFP4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "scale_block_quantize": 8,
            "scale_channel_quantize": -1,
            "modules_to_not_convert": ["conv"],
        }
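

# The suite is normally collected by pytest; this guard is a convenience for
# running the file directly.
if __name__ == "__main__":
    unittest.main()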