From 4267d8f4eb98449d9d29ffbb087d9bdd7690dbab Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 15 May 2025 08:55:18 +0200 Subject: [PATCH 1/4] [Single File] GGUF/Single File Support for HiDream (#11550) * update * update * update * update * update * update * update --- .../api/models/hidream_image_transformer.md | 16 +++++++++++ src/diffusers/loaders/single_file_model.py | 5 ++++ src/diffusers/loaders/single_file_utils.py | 13 +++++++++ .../transformers/transformer_hidream_image.py | 4 +-- .../hidream_image/pipeline_hidream_image.py | 6 ++-- tests/quantization/gguf/test_gguf.py | 28 +++++++++++++++++++ 6 files changed, 67 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/models/hidream_image_transformer.md b/docs/source/en/api/models/hidream_image_transformer.md index 4218e7f56b..5dbf40b5a1 100644 --- a/docs/source/en/api/models/hidream_image_transformer.md +++ b/docs/source/en/api/models/hidream_image_transformer.md @@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16) ``` +## Loading GGUF quantized checkpoints for HiDream-I1 + +GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file` + +```python +import torch +from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel + +ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf" +transformer = HiDreamImageTransformer2DModel.from_single_file( + ckpt_path, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16), + torch_dtype=torch.bfloat16 +) +``` + ## HiDreamImageTransformer2DModel [[autodoc]] HiDreamImageTransformer2DModel diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index ade2e457d8..6919c4949d 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -31,6 +31,7 @@ from .single_file_utils import ( convert_autoencoder_dc_checkpoint_to_diffusers, convert_controlnet_checkpoint, convert_flux_transformer_checkpoint_to_diffusers, + convert_hidream_transformer_to_diffusers, convert_hunyuan_video_transformer_to_diffusers, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, @@ -133,6 +134,10 @@ SINGLE_FILE_LOADABLE_CLASSES = { "checkpoint_mapping_fn": convert_wan_vae_to_diffusers, "default_subfolder": "vae", }, + "HiDreamImageTransformer2DModel": { + "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers, + "default_subfolder": "transformer", + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 3a2855df2d..5cdc381918 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = { ], "wan": ["model.diffusion_model.head.modulation", "head.modulation"], "wan_vae": "decoder.middle.0.residual.0.gamma", + "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias", } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = { "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, + "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"}, } # Use to configure model sample size when original config is provided @@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint): elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint: # All Wan models use the same VAE so we can use the same default model repo to fetch the config model_type = "wan-t2v-14B" + elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint: + model_type = "hidream" else: model_type = "v1" @@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs): converted_state_dict[key] = value return converted_state_dict + + +def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs): + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + return checkpoint diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index 06f47fcbaf..77902dcf58 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_utils import ModelMixin from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers @@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module): ) -class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"] diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index 6fe74cbd9a..17bf0a3fe8 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch - >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM - >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline + >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from diffusers import HiDreamImagePipeline - >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( ... "meta-llama/Meta-Llama-3.1-8B-Instruct", ... output_hidden_states=True, diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 9f54ecf6c6..ae3900459d 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -12,6 +12,7 @@ from diffusers import ( FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig, + HiDreamImageTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline, ) @@ -549,3 +550,30 @@ class FluxControlLoRAGGUFTests(unittest.TestCase): max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) self.assertTrue(max_diff < 1e-3) + + +class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf" + torch_dtype = torch.bfloat16 + model_cls = HiDreamImageTransformer2DModel + expected_memory_use_in_gb = 8 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states_t5": torch.randn( + (1, 128, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "encoder_hidden_states_llama3": torch.randn( + (32, 1, 128, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_embeds": torch.randn( + (1, 2048), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype), + } From 3a6caba8e47e367d092eeaa4c165902ab966c07f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 15 May 2025 02:08:18 -0700 Subject: [PATCH 2/4] [gguf] Refactor __torch_function__ to avoid unnecessary computation (#11551) * [gguf] Refactor __torch_function__ to avoid unnecessary computation This helps with torch.compile compilation lantency. Avoiding unnecessary computation should also lead to a slightly improved eager latency. * Apply style fixes --------- Co-authored-by: Sayak Paul Co-authored-by: github-actions[bot] --- src/diffusers/quantizers/gguf/utils.py | 27 ++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index de82dcab07..531fd61273 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -408,6 +408,18 @@ class GGUFParameter(torch.nn.Parameter): def as_tensor(self): return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad) + @staticmethod + def _extract_quant_type(args): + # When converting from original format checkpoints we often use splits, cats etc on tensors + # this method ensures that the returned tensor type from those operations remains GGUFParameter + # so that we preserve quant_type information + for arg in args: + if isinstance(arg, list) and isinstance(arg[0], GGUFParameter): + return arg[0].quant_type + if isinstance(arg, GGUFParameter): + return arg.quant_type + return None + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: @@ -415,22 +427,13 @@ class GGUFParameter(torch.nn.Parameter): result = super().__torch_function__(func, types, args, kwargs) - # When converting from original format checkpoints we often use splits, cats etc on tensors - # this method ensures that the returned tensor type from those operations remains GGUFParameter - # so that we preserve quant_type information - quant_type = None - for arg in args: - if isinstance(arg, list) and isinstance(arg[0], GGUFParameter): - quant_type = arg[0].quant_type - break - if isinstance(arg, GGUFParameter): - quant_type = arg.quant_type - break if isinstance(result, torch.Tensor): + quant_type = cls._extract_quant_type(args) return cls(result, quant_type=quant_type) # Handle tuples and lists - elif isinstance(result, (tuple, list)): + elif type(result) in (list, tuple): # Preserve the original type (tuple or list) + quant_type = cls._extract_quant_type(args) wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result] return type(result)(wrapped) else: From 20379d9d1395b8e95977faf80facff43065ba75f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 15 May 2025 17:16:44 +0530 Subject: [PATCH 3/4] [tests] add tests for combining layerwise upcasting and groupoffloading. (#11558) * add tests for combining layerwise upcasting and groupoffloading. * feedback --- tests/models/test_modeling_common.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 58edeb55c4..0b17d7977a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1580,6 +1580,34 @@ class ModelTesterMixin: self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) + @parameterized.expand([(False, "block_level"), (True, "leaf_level")]) + @require_torch_accelerator + @torch.no_grad() + def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type): + torch.manual_seed(0) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + if not getattr(model, "_supports_group_offloading", True): + return + + model.to(torch_device) + model.eval() + _ = model(**inputs_dict)[0] + + torch.manual_seed(0) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + storage_dtype, compute_dtype = torch.float16, torch.float32 + inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**init_dict) + model.eval() + additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1} + model.enable_group_offload( + torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs + ) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + _ = model(**inputs_dict)[0] + def test_auto_model(self, expected_max_diff=5e-5): if self.forward_requires_fresh_args: model = self.model_class(**self.init_dict) From 9836f0e000cfd826a7a5099002253ed2becc13e0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 15 May 2025 19:11:24 +0530 Subject: [PATCH 4/4] [docs] Regional compilation docs (#11556) * add regional compilation docs. * minor. * reviwer feedback. * Update docs/source/en/optimization/torch2.0.md Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> --- docs/source/en/optimization/torch2.0.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/source/en/optimization/torch2.0.md b/docs/source/en/optimization/torch2.0.md index 01ea00310a..cc69eceff3 100644 --- a/docs/source/en/optimization/torch2.0.md +++ b/docs/source/en/optimization/torch2.0.md @@ -78,6 +78,23 @@ For more information and different options about `torch.compile`, refer to the [ > [!TIP] > Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial. +### Regional compilation + +Compiling the whole model usually has a big problem space for optimization. Models are often composed of multiple repeated blocks. [Regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) compiles the repeated block first (a transformer encoder block, for example), so that the Torch compiler would re-use its cached/optimized generated code for the other blocks, reducing (often massively) the cold start compilation time observed on the first inference call. + +Enabling regional compilation might require simple yet intrusive changes to the +modeling code. However, 🤗 Accelerate provides a utility [`compile_regions()`](https://huggingface.co/docs/accelerate/main/en/usage_guides/compilation#how-to-use-regional-compilation) which automatically compiles +the repeated blocks of the provided `nn.Module` sequentially, and the rest of the model separately. This helps with reducing cold start time while keeping most (if not all) of the speedup you would get from full compilation. + +```py +# Make sure you're on the latest `accelerate`: `pip install -U accelerate`. +from accelerate.utils import compile_regions + +pipe.unet = compile_regions(pipe.unet, mode="reduce-overhead", fullgraph=True) +``` + +As you may have noticed `compile_regions()` takes the same arguments as `torch.compile()`, allowing flexibility. + ## Benchmark We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).