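"""Nightly tests for Optimum Quanto quantization support in diffusers.

`QuantoBaseTesterMixin` defines the checks shared by all quantized weight
dtypes. `FluxTransformerQuantoMixin` binds them to `FluxTransformer2DModel`,
and the concrete `FluxTransformer*WeightsTest` classes at the bottom of the
file select float8, int8, int4, or int2 weights.
"""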

import gc
import tempfile
import unittest

from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available

from ...testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    enable_full_determinism,
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_accelerator,
    require_torch_cuda_compatibility,
    torch_device,
)


if is_optimum_quanto_available():
    from optimum.quanto import QLinear

if is_torch_available():
    import torch

    from ..utils import LoRALayer, get_memory_consumption_stat

enable_full_determinism()


@nightly
@require_accelerator
@require_accelerate
class QuantoBaseTesterMixin:
    model_id = None
    pipeline_model_id = None
    model_cls = None
    torch_dtype = torch.bfloat16
    # the expected minimum ratio of unquantized to quantized peak memory usage
    # (see `test_quanto_memory_usage`)
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False
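
    # Each test starts and ends with a clean accelerator state so that the
    # peak-memory measurements below reflect only the test at hand.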
    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()
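
    # Subclasses override `get_dummy_init_kwargs` to select the quantized weight
    # dtype; `get_dummy_model_init_kwargs` wraps the result in a `QuantoConfig`.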
    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
        }
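
    # `QLinear` subclasses `torch.nn.Linear`, so every linear layer in the
    # quantized model should satisfy both `isinstance` checks below.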
    def test_quanto_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert isinstance(module, QLinear)
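
    # Loads the same checkpoint twice (unquantized and quantized) and checks the
    # measured peak-memory ratio against `expected_memory_reduction`.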
    def test_quanto_memory_usage(self):
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
        }

        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
        unquantized_model.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        quantized_model.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

    def test_keep_modules_in_fp32(self):
        r"""
        A simple test to check that the modules listed in `_keep_in_fp32_modules`
        are kept in fp32 while the rest of the model is quantized.
        """
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to(torch_device)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32
        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
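
    # Layers named in `modules_to_not_convert` must be skipped during
    # quantization and remain regular `torch.nn.Linear` modules.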
    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()

        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = QuantoConfig(**quantization_config_kwargs)

        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to(torch_device)

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not isinstance(module, QLinear)
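
    # A quantized model rejects any dtype change (`.to(dtype=...)`, `.float()`,
    # `.half()`); only a plain device move is allowed.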
    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device` and `dtype`
            device_0 = f"{torch_device}:0"
            model.to(device=device_0, dtype=torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.float()

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.half()

        # This should work
        model.to(torch_device)
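
    # Round-trips the quantized model through `save_pretrained` /
    # `from_pretrained` and checks that the outputs still match.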
    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
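
    # Compares eager and `torch.compile` outputs via cosine distance; only runs
    # for dtypes that set `_test_torch_compile = True`.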
    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
        assert max_diff < 1e-3
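
    # Passing a `device_map` together with quanto quantization is unsupported
    # and must raise a `ValueError`.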
    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
            )
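

# Binds the generic quantization checks above to `FluxTransformer2DModel` and
# adds Flux-specific dummy inputs plus CPU-offload and LoRA-training tests.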
class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
    model_id = "hf-internal-testing/tiny-flux-transformer"
    model_cls = FluxTransformer2DModel
    pipeline_cls = FluxPipeline
    torch_dtype = torch.bfloat16
    keep_in_fp32_module = "proj_out"
    modules_to_not_convert = ["proj_out"]
    _test_torch_compile = False
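
    # Dummy inputs match the Flux transformer's expected shapes: 4096 image
    # tokens, 512 text tokens, and per-token position ids with 3 channels.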
    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 768),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
        }
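
    # Much smaller shapes than `get_dummy_inputs`, to keep the backward pass in
    # `test_training` cheap.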
    def get_dummy_training_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )

        torch.manual_seed(seed)
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }
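
    # Runs a full pipeline with `enable_model_cpu_offload` to check that the
    # quantized transformer survives being moved on and off the accelerator.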
    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            quantization_config=QuantoConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
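
    # Freezes the quantized base model, attaches small LoRA adapters to the
    # attention projections, and checks that gradients flow into the adapters.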
    def test_training(self):
        quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            # freeze the model as only adapter layers will be trained
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if isinstance(module, Attention):
                module.to_q = LoRALayer(module.to_q, rank=4)
                module.to_k = LoRALayer(module.to_k, rank=4)
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_training_inputs(torch_device)
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
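

# Concrete test cases: each subclass only selects the quantized weight dtype
# (and, for int8, opts into the torch.compile check).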
class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.6

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}


class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.6
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int8"}
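

# The int4 and int2 variants carry `require_torch_cuda_compatibility(8.0)`
# because these weight dtypes need newer CUDA hardware.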
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int4"}


@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int2"}