Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Merge branch 'main' into integrations/first-block-cache-2
@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## Loading GGUF quantized checkpoints for HiDream-I1

GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using [`~FromOriginalModelMixin.from_single_file`].

```python
import torch
from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel

ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
transformer = HiDreamImageTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16
)
```
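
The quantized transformer can then be passed to `HiDreamImagePipeline`. A minimal sketch, assuming the Llama-3.1 tokenizer and text encoder are loaded separately as in the pipeline example doc string further down in this commit (repository ids here are illustrative and may need to be adapted):

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from diffusers import HiDreamImagePipeline

# The Llama-3.1 tokenizer and text encoder are loaded separately and passed in,
# mirroring the pipeline example doc string.
tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    torch_dtype=torch.bfloat16,
)

# Reuse the GGUF-quantized `transformer` created above in place of the default weights.
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    transformer=transformer,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe("A cat holding a sign that says hello world").images[0]
```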

## HiDreamImageTransformer2DModel

[[autodoc]] HiDreamImageTransformer2DModel

@@ -78,6 +78,23 @@ For more information and different options about `torch.compile`, refer to the [
> [!TIP]
> Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.

### Regional compilation

Compiling the whole model usually presents a large problem space for optimization. Models are often composed of multiple repeated blocks. [Regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) compiles the repeated block first (a transformer encoder block, for example) so that the Torch compiler can reuse its cached/optimized generated code for the other blocks, reducing (often massively) the cold-start compilation time observed on the first inference call.

Enabling regional compilation might require simple yet intrusive changes to the modeling code. However, 🤗 Accelerate provides a utility, [`compile_regions()`](https://huggingface.co/docs/accelerate/main/en/usage_guides/compilation#how-to-use-regional-compilation), which automatically compiles the repeated blocks of the provided `nn.Module` sequentially, and the rest of the model separately. This helps reduce cold-start time while keeping most (if not all) of the speedup you would get from full compilation.

```py
# Make sure you're on the latest `accelerate`: `pip install -U accelerate`.
from accelerate.utils import compile_regions

pipe.unet = compile_regions(pipe.unet, mode="reduce-overhead", fullgraph=True)
```

As you may have noticed, `compile_regions()` takes the same arguments as `torch.compile()`, allowing flexibility.

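The same call applies when the pipeline's denoiser is a transformer rather than a UNet; a quick sketch, assuming `pipe` is a transformer-based pipeline that has already been loaded:

```py
# `pipe` is assumed to be a transformer-based pipeline loaded earlier.
# compile_regions accepts the usual torch.compile arguments (mode, fullgraph, dynamic, ...).
pipe.transformer = compile_regions(pipe.transformer, mode="reduce-overhead", fullgraph=True)
```
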
## Benchmark

We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).

@@ -31,6 +31,7 @@ from .single_file_utils import (
    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_flux_transformer_checkpoint_to_diffusers,
    convert_hidream_transformer_to_diffusers,
    convert_hunyuan_video_transformer_to_diffusers,
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
@@ -133,6 +134,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
        "default_subfolder": "vae",
    },
    "HiDreamImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
}

@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
    ],
    "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
    "wan_vae": "decoder.middle.0.residual.0.gamma",
    "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
    "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
    "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
    "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
}

# Use to configure model sample size when original config is provided
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
    elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
        # All Wan models use the same VAE so we can use the same default model repo to fetch the config
        model_type = "wan-t2v-14B"
    elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
        model_type = "hidream"
    else:
        model_type = "v1"

@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
        converted_state_dict[key] = value

    return converted_state_dict


def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    return checkpoint

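For reference, the converter above only strips the `model.diffusion_model.` prefix used by the original single-file checkpoints; a toy illustration (the value is a placeholder and this snippet is not part of the test suite):

```python
# Toy example of the prefix stripping performed by convert_hidream_transformer_to_diffusers.
original = {"model.diffusion_model.double_stream_blocks.0.block.adaLN_modulation.1.bias": None}
converted = convert_hidream_transformer_to_diffusers(original)
assert "double_stream_blocks.0.block.adaLN_modulation.1.bias" in converted
```
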
@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module):
        )


class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    _supports_gradient_checkpointing = True
    _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]

@@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
        >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
        >>> from transformers import AutoTokenizer, LlamaForCausalLM
        >>> from diffusers import HiDreamImagePipeline

        >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
        >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
        >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
        ...     "meta-llama/Meta-Llama-3.1-8B-Instruct",
        ...     output_hidden_states=True,

@@ -408,6 +408,18 @@ class GGUFParameter(torch.nn.Parameter):
    def as_tensor(self):
        return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)

    @staticmethod
    def _extract_quant_type(args):
        # When converting from original format checkpoints we often use splits, cats etc on tensors
        # this method ensures that the returned tensor type from those operations remains GGUFParameter
        # so that we preserve quant_type information
        for arg in args:
            if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
                return arg[0].quant_type
            if isinstance(arg, GGUFParameter):
                return arg.quant_type
        return None

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
@@ -415,22 +427,13 @@ class GGUFParameter(torch.nn.Parameter):

        result = super().__torch_function__(func, types, args, kwargs)

        # When converting from original format checkpoints we often use splits, cats etc on tensors
        # this method ensures that the returned tensor type from those operations remains GGUFParameter
        # so that we preserve quant_type information
        quant_type = None
        for arg in args:
            if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
                quant_type = arg[0].quant_type
                break
            if isinstance(arg, GGUFParameter):
                quant_type = arg.quant_type
                break
        if isinstance(result, torch.Tensor):
            quant_type = cls._extract_quant_type(args)
            return cls(result, quant_type=quant_type)
        # Handle tuples and lists
        elif isinstance(result, (tuple, list)):
        elif type(result) in (list, tuple):
            # Preserve the original type (tuple or list)
            quant_type = cls._extract_quant_type(args)
            wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
            return type(result)(wrapped)
        else:

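To illustrate the behaviour this refactor preserves, a rough sketch of how `quant_type` survives tensor ops (the import path and the `quant_type` value are assumptions, not taken from the diff):

```python
import torch
from diffusers.quantizers.gguf.utils import GGUFParameter  # import path assumed

a = GGUFParameter(torch.zeros(2, 4), quant_type=0)  # placeholder quant_type value
b = GGUFParameter(torch.ones(2, 4), quant_type=0)

# cat/split and similar ops go through __torch_function__, so the result stays a GGUFParameter
# and keeps the quant_type metadata recovered by _extract_quant_type.
out = torch.cat([a, b], dim=0)
assert isinstance(out, GGUFParameter) and out.quant_type == 0

chunks = torch.chunk(out, 2, dim=0)  # list/tuple results are re-wrapped element-wise
assert all(isinstance(c, GGUFParameter) for c in chunks)
```
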
@@ -1580,6 +1580,34 @@ class ModelTesterMixin:
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))

    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
    @require_torch_accelerator
    @torch.no_grad()
    def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        if not getattr(model, "_supports_group_offloading", True):
            return

        model.to(torch_device)
        model.eval()
        _ = model(**inputs_dict)[0]

        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        storage_dtype, compute_dtype = torch.float16, torch.float32
        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
        model = self.model_class(**init_dict)
        model.eval()
        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
        model.enable_group_offload(
            torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
        )
        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        _ = model(**inputs_dict)[0]

    def test_auto_model(self, expected_max_diff=5e-5):
        if self.forward_requires_fresh_args:
            model = self.model_class(**self.init_dict)

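Outside the test suite, the combination this test exercises looks roughly as follows on a user-facing model; a sketch only, mirroring the call pattern above (the model repo, device string, and dtypes are illustrative):

```py
import torch
from diffusers import HiDreamImageTransformer2DModel

# Any diffusers model exposing these helpers should follow the same pattern.
model = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", subfolder="transformer"
)

# Offload repeated blocks to CPU and stream them back onto the accelerator group by group,
# then store weights in a smaller dtype and upcast them only for compute.
model.enable_group_offload("cuda", offload_type="block_level", num_blocks_per_group=1, use_stream=True)
model.enable_layerwise_casting(storage_dtype=torch.float16, compute_dtype=torch.float32)
```
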
@@ -12,6 +12,7 @@ from diffusers import (
    FluxPipeline,
    FluxTransformer2DModel,
    GGUFQuantizationConfig,
    HiDreamImageTransformer2DModel,
    SD3Transformer2DModel,
    StableDiffusion3Pipeline,
)
@@ -549,3 +550,30 @@ class FluxControlLoRAGGUFTests(unittest.TestCase):

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3)


class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
    torch_dtype = torch.bfloat16
    model_cls = HiDreamImageTransformer2DModel
    expected_memory_use_in_gb = 8

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states_t5": torch.randn(
                (1, 128, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "encoder_hidden_states_llama3": torch.randn(
                (32, 1, 128, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_embeds": torch.randn(
                (1, 2048),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }