1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into integrations/first-block-cache-2

This commit is contained in:
Aryan
2025-05-15 21:38:57 +02:00
9 changed files with 127 additions and 17 deletions

View File

@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## Loading GGUF quantized checkpoints for HiDream-I1
GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file`
```python
import torch
from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
transformer = HiDreamImageTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16
)
```
## HiDreamImageTransformer2DModel
[[autodoc]] HiDreamImageTransformer2DModel

View File

@@ -78,6 +78,23 @@ For more information and different options about `torch.compile`, refer to the [
> [!TIP]
> Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.
### Regional compilation
Compiling the whole model usually has a big problem space for optimization. Models are often composed of multiple repeated blocks. [Regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) compiles the repeated block first (a transformer encoder block, for example), so that the Torch compiler would re-use its cached/optimized generated code for the other blocks, reducing (often massively) the cold start compilation time observed on the first inference call.
Enabling regional compilation might require simple yet intrusive changes to the
modeling code. However, 🤗 Accelerate provides a utility [`compile_regions()`](https://huggingface.co/docs/accelerate/main/en/usage_guides/compilation#how-to-use-regional-compilation) which automatically compiles
the repeated blocks of the provided `nn.Module` sequentially, and the rest of the model separately. This helps with reducing cold start time while keeping most (if not all) of the speedup you would get from full compilation.
```py
# Make sure you're on the latest `accelerate`: `pip install -U accelerate`.
from accelerate.utils import compile_regions
pipe.unet = compile_regions(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
As you may have noticed `compile_regions()` takes the same arguments as `torch.compile()`, allowing flexibility.
## Benchmark
We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).

View File

@@ -31,6 +31,7 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
@@ -133,6 +134,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
"HiDreamImageTransformer2DModel": {
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
}

View File

@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
}
# Use to configure model sample size when original config is provided
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
model_type = "hidream"
else:
model_type = "v1"
@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
converted_state_dict[key] = value
return converted_state_dict
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
return checkpoint

View File

@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module):
)
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]

View File

@@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
... output_hidden_states=True,

View File

@@ -408,6 +408,18 @@ class GGUFParameter(torch.nn.Parameter):
def as_tensor(self):
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
@staticmethod
def _extract_quant_type(args):
# When converting from original format checkpoints we often use splits, cats etc on tensors
# this method ensures that the returned tensor type from those operations remains GGUFParameter
# so that we preserve quant_type information
for arg in args:
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
return arg[0].quant_type
if isinstance(arg, GGUFParameter):
return arg.quant_type
return None
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
@@ -415,22 +427,13 @@ class GGUFParameter(torch.nn.Parameter):
result = super().__torch_function__(func, types, args, kwargs)
# When converting from original format checkpoints we often use splits, cats etc on tensors
# this method ensures that the returned tensor type from those operations remains GGUFParameter
# so that we preserve quant_type information
quant_type = None
for arg in args:
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
quant_type = arg[0].quant_type
break
if isinstance(arg, GGUFParameter):
quant_type = arg.quant_type
break
if isinstance(result, torch.Tensor):
quant_type = cls._extract_quant_type(args)
return cls(result, quant_type=quant_type)
# Handle tuples and lists
elif isinstance(result, (tuple, list)):
elif type(result) in (list, tuple):
# Preserve the original type (tuple or list)
quant_type = cls._extract_quant_type(args)
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
return type(result)(wrapped)
else:

View File

@@ -1580,6 +1580,34 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if not getattr(model, "_supports_group_offloading", True):
return
model.to(torch_device)
model.eval()
_ = model(**inputs_dict)[0]
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
storage_dtype, compute_dtype = torch.float16, torch.float32
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
model.enable_group_offload(
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)

View File

@@ -12,6 +12,7 @@ from diffusers import (
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
@@ -549,3 +550,30 @@ class FluxControlLoRAGGUFTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
torch_dtype = torch.bfloat16
model_cls = HiDreamImageTransformer2DModel
expected_memory_use_in_gb = 8
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states_t5": torch.randn(
(1, 128, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"encoder_hidden_states_llama3": torch.randn(
(32, 1, 128, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"pooled_embeds": torch.randn(
(1, 2048),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
}