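"""Nightly tests for NVIDIA ModelOpt quantization support in diffusers."""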

import gc
import tempfile
import unittest

from diffusers import NVIDIAModelOptConfig, SD3Transformer2DModel, StableDiffusion3Pipeline
from diffusers.utils import is_nvidia_modelopt_available, is_torch_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    enable_full_determinism,
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_accelerator,
    require_modelopt_version_greater_or_equal,
    require_torch_cuda_compatibility,
    torch_device,
)


if is_nvidia_modelopt_available():
    import modelopt.torch.quantization as mtq

if is_torch_available():
    import torch

    from ..utils import LoRALayer, get_memory_consumption_stat

enable_full_determinism()


@nightly
@require_big_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptBaseTesterMixin:
    model_id = "hf-internal-testing/tiny-sd3-pipe"
    model_cls = SD3Transformer2DModel
    pipeline_cls = StableDiffusion3Pipeline
    torch_dtype = torch.bfloat16
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False
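
    # setUp/tearDown reset peak-memory stats and empty the accelerator cache so
    # memory measurements in one test cannot leak into the next.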
    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def get_dummy_init_kwargs(self):
        return {"quant_type": "FP8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()),
            "subfolder": "transformer",
        }
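
    # Every nn.Linear in the loaded model should have been replaced by a
    # ModelOpt-quantized module.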
    def test_modelopt_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert mtq.utils.is_quantized(module)
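
    # Peak memory of the bf16 baseline divided by that of the quantized model
    # must meet the class-level `expected_memory_reduction` threshold.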
    def test_modelopt_memory_usage(self):
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
        }

        unquantized_model = self.model_cls.from_pretrained(
            self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer"
        )
        unquantized_model.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        quantized_model.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
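
    # Linear modules listed in `_keep_in_fp32_modules` must keep float32
    # weights even though the rest of the model is quantized.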
    def test_keep_modules_in_fp32(self):
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to(torch_device)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32
        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
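
    # Modules named in `modules_to_not_convert` should be left unquantized.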
    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()
        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = NVIDIAModelOptConfig(**quantization_config_kwargs)
        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to(torch_device)

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not mtq.utils.is_quantized(module)

    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        # Casting a quantized model to a new dtype is not supported: each of
        # the casts below is expected to raise a ValueError.
        with self.assertRaises(ValueError):
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            device_0 = f"{torch_device}:0"
            model.to(device=device_0, dtype=torch.float16)

        with self.assertRaises(ValueError):
            model.float()

        with self.assertRaises(ValueError):
            model.half()

        # Moving to a device without changing the dtype is still allowed.
        model.to(torch_device)
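
    # Round-trip the quantized model through save_pretrained/from_pretrained
    # and check that the reloaded model produces matching outputs.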
    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
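
    # Eager and torch.compile outputs should agree (cosine-distance check).
    # Only subclasses that set `_test_torch_compile = True` exercise this path.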
    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
        assert max_diff < 1e-3
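
    # Loading a quantized model with this style of device_map mapping is
    # expected to be rejected with a ValueError.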
    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(),
                device_map={0: "8GB", "cpu": "16GB"},
            )
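
    # Tiny random SD3-style transformer inputs, seeded for determinism and
    # placed on the target device in bf16.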
    def get_dummy_inputs(self):
        batch_size = 1
        seq_len = 16
        height = width = 32
        num_latent_channels = 4
        caption_channels = 8

        torch.manual_seed(0)
        hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(
            torch_device, dtype=torch.bfloat16
        )
        encoder_hidden_states = torch.randn((batch_size, seq_len, caption_channels)).to(
            torch_device, dtype=torch.bfloat16
        )
        timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
        }
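
    # Smoke test: the quantized transformer should run end-to-end inside a
    # pipeline with model CPU offload enabled.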
    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            self.model_id,
            quantization_config=NVIDIAModelOptConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(self.model_id, transformer=transformer, torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

    def test_training(self):
        quantization_config = NVIDIAModelOptConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            self.model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        # Freeze the base model and cast 1-dim parameters (e.g. norms) to fp32.
        for param in quantized_model.parameters():
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        # Attach trainable LoRA adapters to the attention projections.
        for _, module in quantized_model.named_modules():
            if hasattr(module, "to_q"):
                module.to_q = LoRALayer(module.to_q, rank=4)
            if hasattr(module, "to_k"):
                module.to_k = LoRALayer(module.to_k, rank=4)
            if hasattr(module, "to_v"):
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_inputs()
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        # Only the LoRA adapters should have received gradients.
        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
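

# The concrete test classes below parameterize the mixin with a specific
# quant_type and the memory-reduction threshold expected for that format.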
class SanaTransformerFP8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.6

    def get_dummy_init_kwargs(self):
        return {"quant_type": "FP8"}


class SanaTransformerINT8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.6
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"quant_type": "INT8"}


@require_torch_cuda_compatibility(8.0)
class SanaTransformerINT4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "INT4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "disable_conv_quantization": True,
        }


@require_torch_cuda_compatibility(8.0)
class SanaTransformerNF4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "NF4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "scale_block_quantize": 8,
            "scale_channel_quantize": -1,
            "modules_to_not_convert": ["conv"],
        }


@require_torch_cuda_compatibility(8.0)
class SanaTransformerNVFP4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "NVFP4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "scale_block_quantize": 8,
            "scale_channel_quantize": -1,
            "modules_to_not_convert": ["conv"],
        }