mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge branch 'main' into mochi
This commit is contained in:
@@ -188,6 +188,8 @@
|
||||
title: Metal Performance Shaders (MPS)
|
||||
- local: optimization/habana
|
||||
title: Habana Gaudi
|
||||
- local: optimization/neuron
|
||||
title: AWS Neuron
|
||||
title: Optimized hardware
|
||||
title: Accelerate inference and reduce memory
|
||||
- sections:
|
||||
|
||||
61
docs/source/en/optimization/neuron.md
Normal file
61
docs/source/en/optimization/neuron.md
Normal file
@@ -0,0 +1,61 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AWS Neuron
|
||||
|
||||
Diffusers functionalities are available on [AWS Inf2 instances](https://aws.amazon.com/ec2/instance-types/inf2/), which are EC2 instances powered by [Neuron machine learning accelerators](https://aws.amazon.com/machine-learning/inferentia/). These instances aim to provide better compute performance (higher throughput, lower latency) with good cost-efficiency, making them good candidates for AWS users to deploy diffusion models to production.
|
||||
|
||||
[Optimum Neuron](https://huggingface.co/docs/optimum-neuron/en/index) is the interface between Hugging Face libraries and AWS Accelerators, including AWS [Trainium](https://aws.amazon.com/machine-learning/trainium/) and AWS [Inferentia](https://aws.amazon.com/machine-learning/inferentia/). It supports many of the features in Diffusers with similar APIs, so it is easier to learn if you're already familiar with Diffusers. Once you have created an AWS Inf2 instance, install Optimum Neuron.
|
||||
|
||||
```bash
|
||||
python -m pip install --upgrade-strategy eager optimum[neuronx]
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
We provide pre-built [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) (DLAMI) and Optimum Neuron containers for Amazon SageMaker. It's recommended to correctly set up your environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
The example below demonstrates how to generate images with the Stable Diffusion XL model on an inf2.8xlarge instance (you can switch to cheaper inf2.xlarge instances once the model is compiled). To generate some images, use the [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] class, which is similar to the [`StableDiffusionXLPipeline`] class in Diffusers.
|
||||
|
||||
Unlike Diffusers, you need to compile models in the pipeline to the Neuron format, `.neuron`. Launch the following command to export the model to the `.neuron` format.
|
||||
|
||||
```bash
|
||||
optimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \
|
||||
--batch_size 1 \
|
||||
--height 1024 `# height in pixels of generated image, eg. 768, 1024` \
|
||||
--width 1024 `# width in pixels of generated image, eg. 768, 1024` \
|
||||
--num_images_per_prompt 1 `# number of images to generate per prompt, defaults to 1` \
|
||||
--auto_cast matmul `# cast only matrix multiplication operations` \
|
||||
--auto_cast_type bf16 `# cast operations from FP32 to BF16` \
|
||||
sd_neuron_xl/
|
||||
```
|
||||
|
||||
Now generate some images with the pre-compiled SDXL model.
|
||||
|
||||
```python
|
||||
>>> from optimum.neuron import NeuronStableDiffusionXLPipeline
|
||||
|
||||
>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained("sd_neuron_xl/")
|
||||
>>> prompt = "a pig with wings flying in floating US dollar banknotes in the air, skyscrapers behind, warm color palette, muted colors, detailed, 8k"
|
||||
>>> image = stable_diffusion_xl(prompt).images[0]
|
||||
```
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/Jingya/document_images/resolve/main/optimum/neuron/sdxl_pig.png"
|
||||
width="256"
|
||||
height="256"
|
||||
alt="peggy generated by sdxl on inf2"
|
||||
/>
|
||||
|
||||
Feel free to check out more guides and examples on different use cases from the Optimum Neuron [documentation](https://huggingface.co/docs/optimum-neuron/en/inference_tutorials/stable_diffusion#generate-images-with-stable-diffusion-models-on-aws-inferentia)!
|
||||
@@ -2198,8 +2198,8 @@ def main(args):
|
||||
|
||||
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
model_input.shape[0],
|
||||
model_input.shape[2],
|
||||
model_input.shape[3],
|
||||
model_input.shape[2] // 2,
|
||||
model_input.shape[3] // 2,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
)
|
||||
@@ -2253,8 +2253,8 @@ def main(args):
|
||||
)[0]
|
||||
model_pred = FluxPipeline._unpack_latents(
|
||||
model_pred,
|
||||
height=int(model_input.shape[2] * vae_scale_factor / 2),
|
||||
width=int(model_input.shape[3] * vae_scale_factor / 2),
|
||||
height=model_input.shape[2] * vae_scale_factor,
|
||||
width=model_input.shape[3] * vae_scale_factor,
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
)
|
||||
|
||||
|
||||
@@ -1256,8 +1256,8 @@ def main(args):
|
||||
|
||||
latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
|
||||
batch_size=pixel_latents_tmp.shape[0],
|
||||
height=pixel_latents_tmp.shape[2],
|
||||
width=pixel_latents_tmp.shape[3],
|
||||
height=pixel_latents_tmp.shape[2] // 2,
|
||||
width=pixel_latents_tmp.shape[3] // 2,
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype,
|
||||
)
|
||||
|
||||
@@ -1540,12 +1540,12 @@ def main(args):
|
||||
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
|
||||
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
model_input.shape[0],
|
||||
model_input.shape[2],
|
||||
model_input.shape[3],
|
||||
model_input.shape[2] // 2,
|
||||
model_input.shape[3] // 2,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
)
|
||||
@@ -1601,8 +1601,8 @@ def main(args):
|
||||
# upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
|
||||
model_pred = FluxPipeline._unpack_latents(
|
||||
model_pred,
|
||||
height=int(model_input.shape[2] * vae_scale_factor / 2),
|
||||
width=int(model_input.shape[3] * vae_scale_factor / 2),
|
||||
height=model_input.shape[2] * vae_scale_factor,
|
||||
width=model_input.shape[3] * vae_scale_factor,
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
)
|
||||
|
||||
|
||||
@@ -1645,12 +1645,12 @@ def main(args):
|
||||
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
|
||||
|
||||
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
model_input.shape[0],
|
||||
model_input.shape[2],
|
||||
model_input.shape[3],
|
||||
model_input.shape[2] // 2,
|
||||
model_input.shape[3] // 2,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
)
|
||||
@@ -1704,8 +1704,8 @@ def main(args):
|
||||
)[0]
|
||||
model_pred = FluxPipeline._unpack_latents(
|
||||
model_pred,
|
||||
height=int(model_input.shape[2] * vae_scale_factor / 2),
|
||||
width=int(model_input.shape[3] * vae_scale_factor / 2),
|
||||
height=model_input.shape[2] * vae_scale_factor,
|
||||
width=model_input.shape[3] * vae_scale_factor,
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
)
|
||||
|
||||
|
||||
166
examples/research_projects/flux_lora_quantization/README.md
Normal file
166
examples/research_projects/flux_lora_quantization/README.md
Normal file
@@ -0,0 +1,166 @@
|
||||
## LoRA fine-tuning Flux.1 Dev with quantization
|
||||
|
||||
> [!NOTE]
|
||||
> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further.
|
||||
|
||||
This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow:
|
||||
|
||||
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
|
||||
* `train_dreambooth_lora_flux_miniature.py` takes care of training:
|
||||
* Since we already precomputed the text embeddings, we don't load the text encoders.
|
||||
* We load the VAE and use it to precompute the image latents and we then delete it.
|
||||
* Load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, prepare it for 4bit training.
|
||||
* Add LoRA adapter layers to it and then ensure they are kept in FP32 precision.
|
||||
* Train!
|
||||
|
||||
To run training in a memory-optimized manner, we additionally use:
|
||||
|
||||
* 8Bit Adam
|
||||
* Gradient checkpointing
|
||||
|
||||
We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow.
|
||||
|
||||
## Training
|
||||
|
||||
Ensure you have installed the required libraries:
|
||||
|
||||
```bash
|
||||
pip install -U transformers accelerate bitsandbytes peft datasets
|
||||
pip install git+https://github.com/huggingface/diffusers -U
|
||||
```
|
||||
|
||||
Now, compute the text embeddings:
|
||||
|
||||
```bash
|
||||
python compute_embeddings.py
|
||||
```
|
||||
|
||||
It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
|
||||
|
||||
```bash
|
||||
huggingface-cli
|
||||
```
|
||||
|
||||
Then launch:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=accelerate.yaml \
|
||||
train_dreambooth_lora_flux_miniature.py \
|
||||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
||||
--data_df_path="embeddings.parquet" \
|
||||
--output_dir="yarn_art_lora_flux_nf4" \
|
||||
--mixed_precision="fp16" \
|
||||
--use_8bit_adam \
|
||||
--weighting_scheme="none" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--learning_rate=1e-4 \
|
||||
--guidance_scale=1 \
|
||||
--report_to="wandb" \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--cache_latents \
|
||||
--rank=4 \
|
||||
--max_train_steps=700 \
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
We can direcly pass a quantized checkpoint path, too:
|
||||
|
||||
```diff
|
||||
+ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
|
||||
```
|
||||
|
||||
Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`.
|
||||
|
||||
We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed:
|
||||
|
||||
```bash
|
||||
pip install -Uq deepspeed
|
||||
```
|
||||
|
||||
And then launch:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=ds2.yaml \
|
||||
train_dreambooth_lora_flux_miniature.py \
|
||||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
||||
--data_df_path="embeddings.parquet" \
|
||||
--output_dir="yarn_art_lora_flux_nf4" \
|
||||
--mixed_precision="no" \
|
||||
--use_8bit_adam \
|
||||
--weighting_scheme="none" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--learning_rate=1e-4 \
|
||||
--guidance_scale=1 \
|
||||
--report_to="wandb" \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--cache_latents \
|
||||
--rank=4 \
|
||||
--max_train_steps=700 \
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example:
|
||||
|
||||
1. First, load the original model and merge the LoRA params into it:
|
||||
|
||||
```py
|
||||
from diffusers import FluxPipeline
|
||||
import torch
|
||||
|
||||
ckpt_id = "black-forest-labs/FLUX.1-dev"
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
|
||||
pipeline.transformer.save_pretrained("fused_transformer")
|
||||
```
|
||||
|
||||
2. Quantize the model and run inference
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
|
||||
import torch
|
||||
|
||||
ckpt_id = "black-forest-labs/FLUX.1-dev"
|
||||
bnb_4bit_compute_dtype = torch.float16
|
||||
nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
||||
)
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"fused_transformer",
|
||||
quantization_config=nf4_config,
|
||||
torch_dtype=bnb_4bit_compute_dtype,
|
||||
)
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
image = pipeline(
|
||||
"a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
|
||||
).images[0]
|
||||
image.save("yarn_merged.png")
|
||||
```
|
||||
|
||||
| Dequantize, merge, quantize | Merging directly into quantized model |
|
||||
|-------|-------|
|
||||
|  |  |
|
||||
|
||||
As we can notice the first column result follows the style more closely.
|
||||
@@ -0,0 +1,17 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: NO
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: true
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
|
||||
MAX_SEQ_LENGTH = 77
|
||||
OUTPUT_PATH = "embeddings.parquet"
|
||||
|
||||
|
||||
def generate_image_hash(image):
|
||||
return insecure_hashlib.sha256(image.tobytes()).hexdigest()
|
||||
|
||||
|
||||
def load_flux_dev_pipeline():
|
||||
id = "black-forest-labs/FLUX.1-dev"
|
||||
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
|
||||
)
|
||||
return pipeline
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_embeddings(pipeline, prompts, max_sequence_length):
|
||||
all_prompt_embeds = []
|
||||
all_pooled_prompt_embeds = []
|
||||
all_text_ids = []
|
||||
for prompt in tqdm(prompts, desc="Encoding prompts."):
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
text_ids,
|
||||
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
|
||||
all_prompt_embeds.append(prompt_embeds)
|
||||
all_pooled_prompt_embeds.append(pooled_prompt_embeds)
|
||||
all_text_ids.append(text_ids)
|
||||
|
||||
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
|
||||
print(f"Max memory allocated: {max_memory:.3f} GB")
|
||||
return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
|
||||
|
||||
|
||||
def run(args):
|
||||
dataset = load_dataset("Norod78/Yarn-art-style", split="train")
|
||||
image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
|
||||
all_prompts = list(image_prompts.values())
|
||||
print(f"{len(all_prompts)=}")
|
||||
|
||||
pipeline = load_flux_dev_pipeline()
|
||||
all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
|
||||
pipeline, all_prompts, args.max_sequence_length
|
||||
)
|
||||
|
||||
data = []
|
||||
for i, (image_hash, _) in enumerate(image_prompts.items()):
|
||||
data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
|
||||
print(f"{len(data)=}")
|
||||
|
||||
# Create a DataFrame
|
||||
embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
|
||||
df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
|
||||
print(f"{len(df)=}")
|
||||
|
||||
# Convert embedding lists to arrays (for proper storage in parquet)
|
||||
for col in embedding_cols:
|
||||
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
|
||||
|
||||
# Save the dataframe to a parquet file
|
||||
df.to_parquet(args.output_path)
|
||||
print(f"Data successfully serialized to {args.output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--max_sequence_length",
|
||||
type=int,
|
||||
default=MAX_SEQ_LENGTH,
|
||||
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
|
||||
args = parser.parse_args()
|
||||
|
||||
run(args)
|
||||
23
examples/research_projects/flux_lora_quantization/ds2.yaml
Normal file
23
examples/research_projects/flux_lora_quantization/ds2.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
File diff suppressed because it is too large
Load Diff
@@ -195,13 +195,13 @@ class FluxPipeline(
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 64
|
||||
self.default_sample_size = 128
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
@@ -386,8 +386,10 @@ class FluxPipeline(
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -425,9 +427,9 @@ class FluxPipeline(
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
@@ -452,10 +454,10 @@ class FluxPipeline(
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -499,8 +501,8 @@ class FluxPipeline(
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
@@ -517,7 +519,7 @@ class FluxPipeline(
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
@@ -216,13 +216,13 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
controlnet=controlnet,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 64
|
||||
self.default_sample_size = 128
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
@@ -410,8 +410,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -450,9 +452,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
@@ -479,10 +481,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -498,8 +500,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
@@ -516,7 +518,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
@@ -228,13 +228,13 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
controlnet=controlnet,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 64
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
@@ -453,8 +453,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -493,9 +495,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
@@ -522,10 +524,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -549,11 +551,11 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
@@ -852,7 +854,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
control_mode = control_mode.reshape([-1, 1])
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
|
||||
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
|
||||
@@ -231,7 +231,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
)
|
||||
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
@@ -244,7 +244,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 64
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
@@ -467,8 +467,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -520,9 +522,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
@@ -549,10 +551,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -576,11 +578,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
@@ -622,8 +624,8 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
device,
|
||||
generator,
|
||||
):
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
@@ -930,19 +932,22 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
# xlab controlnet has a input_hint_block and instantx controlnet does not
|
||||
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
|
||||
if self.controlnet.input_hint_block is None:
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# set control mode
|
||||
if control_mode is not None:
|
||||
@@ -952,7 +957,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
elif isinstance(self.controlnet, FluxMultiControlNetModel):
|
||||
control_images = []
|
||||
|
||||
for control_image_ in control_image:
|
||||
# xlab controlnet has a input_hint_block and instantx controlnet does not
|
||||
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
|
||||
for i, control_image_ in enumerate(control_image):
|
||||
control_image_ = self.prepare_image(
|
||||
image=control_image_,
|
||||
width=width,
|
||||
@@ -964,19 +971,20 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
if self.controlnet.nets[0].input_hint_block is None:
|
||||
# vae encode
|
||||
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
control_images.append(control_image_)
|
||||
|
||||
@@ -996,7 +1004,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
# 6. Prepare timesteps
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
|
||||
image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
|
||||
int(global_width) // self.vae_scale_factor // 2
|
||||
)
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
@@ -1125,6 +1135,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
controlnet_blocks_repeat=controlnet_blocks_repeat,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
@@ -212,13 +212,13 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 64
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
@@ -437,8 +437,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -477,9 +479,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
@@ -506,10 +508,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -532,11 +534,11 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
@@ -736,7 +738,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
|
||||
# 4.Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
|
||||
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
|
||||
@@ -209,7 +209,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
@@ -222,7 +222,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 64
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
@@ -445,8 +445,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -498,9 +500,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
@@ -527,10 +529,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -553,11 +555,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
@@ -598,8 +600,8 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
device,
|
||||
generator,
|
||||
):
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
@@ -866,7 +868,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
|
||||
# 4.Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
|
||||
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
|
||||
@@ -762,8 +762,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
|
||||
@@ -800,8 +800,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
|
||||
@@ -921,8 +921,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
|
||||
@@ -284,7 +284,7 @@ def free_memory():
|
||||
elif torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch_npu.empty_cache()
|
||||
torch_npu.npu.empty_cache()
|
||||
|
||||
|
||||
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
||||
|
||||
Reference in New Issue
Block a user