1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into dependabot/pip/examples/server/jinja2-3.1.6

This commit is contained in:
YiYi Xu
2025-04-08 09:25:19 -10:00
committed by GitHub
11 changed files with 262 additions and 19 deletions

View File

@@ -417,7 +417,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: []
additional_deps: ["peft"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []

View File

@@ -178,6 +178,9 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch
# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# Uncomment the following to also allow recording the current streams.
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
@@ -205,6 +208,7 @@ Group offloading (for CUDA devices with support for asynchronous data transfer s
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to know more about this.
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].

View File

@@ -56,6 +56,7 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage=False,
onload_self: bool = True,
) -> None:
@@ -68,11 +69,14 @@ class ModuleGroup:
self.buffers = buffers or []
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_param_dict = self._init_cpu_param_dict()
if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -112,6 +116,8 @@ class ModuleGroup:
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
current_stream = torch.cuda.current_stream() if self.record_stream else None
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -122,14 +128,22 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
for buffer in group_module.buffers():
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
for param in self.parameters:
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
for buffer in self.buffers:
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
else:
for group_module in self.modules:
@@ -143,11 +157,14 @@ class ModuleGroup:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
if not self.record_stream:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -331,6 +348,7 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
@@ -378,6 +396,10 @@ def apply_group_offloading(
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
@@ -417,11 +439,24 @@ def apply_group_offloading(
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
module=module,
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
module=module,
offload_device=offload_device,
onload_device=onload_device,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -434,6 +469,7 @@ def _apply_group_offloading_block_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
@@ -453,6 +489,14 @@ def _apply_group_offloading_block_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
# Create module groups for ModuleList and Sequential blocks
@@ -475,6 +519,7 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
)
@@ -512,6 +557,7 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
@@ -545,6 +592,14 @@ def _apply_group_offloading_leaf_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
# Create module groups for leaf modules and apply group offloading hooks
@@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
@@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
@@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)

View File

@@ -22,6 +22,8 @@ from ..utils import (
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_bitsandbytes_available,
is_gguf_available,
is_peft_available,
is_peft_version,
is_torch_version,
@@ -68,6 +70,49 @@ TRANSFORMER_NAME = "transformer"
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
def _maybe_dequantize_weight_for_expanded_lora(model, module):
if is_bitsandbytes_available():
from ..quantizers.bitsandbytes import dequantize_bnb_weight
if is_gguf_available():
from ..quantizers.gguf.utils import dequantize_gguf_tensor
is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
if is_bnb_4bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_gguf_quantized and not is_gguf_available():
raise ValueError(
"The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
)
weight_on_cpu = False
if not module.weight.is_cuda:
weight_on_cpu = True
if is_bnb_4bit_quantized:
module_weight = dequantize_bnb_weight(
module.weight.cuda() if weight_on_cpu else module.weight,
state=module.weight.quant_state,
dtype=model.dtype,
).data
elif is_gguf_quantized:
module_weight = dequantize_gguf_tensor(
module.weight.cuda() if weight_on_cpu else module.weight,
)
module_weight = module_weight.to(model.dtype)
else:
module_weight = module.weight.data
if weight_on_cpu:
module_weight = module_weight.cpu()
return module_weight
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -2267,6 +2312,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
overwritten_params = {}
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
is_quantized = hasattr(transformer, "hf_quantizer")
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
@@ -2291,9 +2337,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if tuple(module_weight_shape) == (out_features, in_features):
continue
# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
module_out_features, module_in_features = module_weight_shape
debug_message = ""
if in_features > module_in_features:
debug_message += (
@@ -2316,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
if is_quantized:
module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
# TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
slices = tuple(slice(0, dim) for dim in module_weight_shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
@@ -2416,7 +2464,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
if weight.__class__.__name__ == "Params4bit":
return weight.quant_state.shape
elif weight.__class__.__name__ == "GGUFParameter":
return weight.quant_shape
else:
return weight.shape
if base_module is not None:
return _get_weight_shape(base_module.weight)

View File

@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage=False,
) -> None:
r"""
@@ -594,6 +595,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group,
non_blocking,
use_stream,
record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)

View File

@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
models by reducing the precision of the weights and activations, thus making models more efficient in terms
of both storage and computation.
"""
model, has_been_replaced = _replace_with_bnb_linear(
model, modules_to_not_convert, current_key_name, quantization_config
)
model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
has_been_replaced = any(
isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
for _, replaced_module in model.named_modules()
)
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,16 +285,18 @@ def dequantize_and_replace(
modules_to_not_convert=None,
quantization_config=None,
):
model, has_been_replaced = _dequantize_and_replace(
model, _ = _dequantize_and_replace(
model,
dtype=model.dtype,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
has_been_replaced = any(
isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
)
if not has_been_replaced:
logger.warning(
"For some reason the model has not been properly dequantized. You might see unexpected behavior."
"Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
)
return model

View File

@@ -400,6 +400,8 @@ class GGUFParameter(torch.nn.Parameter):
data = data if data is not None else torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.quant_type = quant_type
block_size, type_size = GGML_QUANT_SIZES[quant_type]
self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)
return self

View File

@@ -1525,8 +1525,9 @@ class ModelTesterMixin:
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
)
@parameterized.expand([False, True])
@require_torch_gpu
def test_group_offloading(self):
def test_group_offloading(self, record_stream):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
@@ -1566,7 +1567,9 @@ class ModelTesterMixin:
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))

View File

@@ -21,8 +21,15 @@ import numpy as np
import pytest
import safetensors.torch
from huggingface_hub import hf_hub_download
from PIL import Image
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers import (
BitsAndBytesConfig,
DiffusionPipeline,
FluxControlPipeline,
FluxTransformer2DModel,
SD3Transformer2DModel,
)
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
@@ -63,6 +70,8 @@ if is_torch_available():
if is_bitsandbytes_available():
import bitsandbytes as bnb
from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@@ -364,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests):
assert key_to_target in str(err_context.exception)
def test_bnb_4bit_logs_warning_for_no_quantization(self):
model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
_ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
assert (
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
in cap_logger.out
)
class BnB4BitTrainingTests(Base4bitTests):
def setUp(self):
@@ -696,6 +717,42 @@ class SlowBnb4BitFluxTests(Base4bitTests):
self.assertTrue(max_diff < 1e-3)
@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
self.pipeline_4bit.enable_model_cpu_offload()
def tearDown(self):
del self.pipeline_4bit
gc.collect()
torch.cuda.empty_cache()
def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
output = self.pipeline_4bit(
prompt=self.prompt,
control_image=Image.new(mode="RGB", size=(256, 256)),
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.Generator().manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
@slow
class BaseBnb4BitSerializationTests(Base4bitTests):
def tearDown(self):

View File

@@ -68,6 +68,8 @@ if is_torch_available():
if is_bitsandbytes_available():
import bitsandbytes as bnb
from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests):
# Check that this does not throw an error
_ = self.model_fp16.to(torch_device)
def test_bnb_8bit_logs_warning_for_no_quantization(self):
model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
_ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
assert (
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
in cap_logger.out
)
class Bnb8bitDeviceTests(Base8bitTests):
def setUp(self) -> None:

View File

@@ -8,12 +8,14 @@ import torch.nn as nn
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FluxControlPipeline,
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
is_gguf_available,
nightly,
@@ -21,6 +23,7 @@ from diffusers.utils.testing_utils import (
require_accelerate,
require_big_gpu_with_torch_cuda,
require_gguf_version_greater_or_equal,
require_peft_backend,
torch_device,
)
@@ -456,3 +459,46 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
@require_peft_backend
@nightly
@require_big_gpu_with_torch_cuda
@require_accelerate
@require_gguf_version_greater_or_equal("0.10.0")
class FluxControlLoRAGGUFTests(unittest.TestCase):
def test_lora_loading(self):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image(
"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png"
)
output = pipe(
prompt=prompt,
control_image=control_image,
height=256,
width=256,
num_inference_steps=10,
guidance_scale=30.0,
output_type="np",
generator=torch.manual_seed(0),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)