diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py index c61ac26928..7cee0bccc3 100644 --- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py +++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py @@ -38,12 +38,12 @@ class NunchakuQuantizer(DiffusersQuantizer): requires_calibration = False required_packages = ["nunchaku", "accelerate"] - dtype_map = {"int4": torch.int8} - if is_fp8_available(): - dtype_map = {"nvfp4": torch.float8_e4m3fn} - def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) + dtype_map = {"int4": torch.int8} + if is_fp8_available(): + dtype_map = {"nvfp4": torch.float8_e4m3fn} + self.dtype_map = dtype_map def validate_environment(self, *args, **kwargs): if not torch.cuda.is_available():