1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[training ] add Kontext i2i training (#11858)

* feat: enable i2i fine-tuning in Kontext script.

* readme

* more checks.

* Apply suggestions from code review

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>

* fixes

* fix

* add proj_mlp to the mix

* Update README_flux.md

add note on installing from commit `05e7a854d0a5661f5b433f6dd5954c224b104f0b`

* fix

* fix

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
This commit is contained in:
Sayak Paul
2025-07-08 21:04:16 +05:30
committed by GitHub
parent ce338d4e4a
commit 01240fecb0
2 changed files with 229 additions and 44 deletions

View File

@@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N
## Training Kontext
[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too.
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
**important**
> [!NOTE]
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
> To do this, execute the following steps in a new virtual environment:
> ```
> git clone https://github.com/huggingface/diffusers
> cd diffusers
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
> pip install -e .
> ```
Below is an example training command:
@@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
perform as expected.
Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
* Condition image
* Target image
* Instruction
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
```bash
accelerate launch train_dreambooth_lora_flux_kontext.py \
--pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
--output_dir="kontext-i2i" \
--dataset_name="kontext-community/relighting" \
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--optimizer="adamw" \
--use_8bit_adam \
--cache_latents \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--lr_warmup_steps=200 \
--max_train_steps=1000 \
--rank=16\
--seed="0"
```
More generally, when performing I2I fine-tuning, we expect you to:
* Have a dataset `kontext-community/relighting`
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
### Misc notes
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
@@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a
Since Flux Kontext finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
## Other notes
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️

View File

@@ -40,7 +40,7 @@ from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from torchvision import transforms
from torchvision.transforms.functional import crop
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
@@ -62,11 +62,7 @@ from diffusers.training_utils import (
free_memory,
parse_buckets_string,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -186,6 +182,7 @@ def log_validation(
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
pipeline_args_cp = pipeline_args.copy()
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
@@ -193,14 +190,16 @@ def log_validation(
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
prompt = pipeline_args_cp.pop("prompt")
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
**pipeline_args_cp,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
generator=generator,
).images[0]
images.append(image)
@@ -310,6 +309,12 @@ def parse_args(input_args=None):
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
parser.add_argument(
"--cond_image_column",
type=str,
default=None,
help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
)
parser.add_argument(
"--caption_column",
type=str,
@@ -330,7 +335,6 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -351,6 +355,12 @@ def parse_args(input_args=None):
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.",
)
parser.add_argument(
"--num_validation_images",
type=int,
@@ -399,7 +409,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
default="flux-dreambooth-lora",
default="flux-kontext-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -716,6 +726,8 @@ def parse_args(input_args=None):
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
if args.cond_image_column is not None:
raise ValueError("Prior preservation isn't supported with I2I training.")
else:
# logger is not available yet
if args.class_data_dir is not None:
@@ -723,6 +735,14 @@ def parse_args(input_args=None):
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
if args.cond_image_column is not None:
assert args.image_column is not None
assert args.caption_column is not None
assert args.dataset_name is not None
assert not args.train_text_encoder
if args.validation_prompt is not None:
assert args.validation_image is None and os.path.exists(args.validation_image)
return args
@@ -742,6 +762,7 @@ class DreamBoothDataset(Dataset):
repeats=1,
center_crop=False,
buckets=None,
args=None,
):
self.center_crop = center_crop
@@ -774,6 +795,10 @@ class DreamBoothDataset(Dataset):
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.cond_image_column is not None and args.cond_image_column not in column_names:
raise ValueError(
f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
@@ -783,7 +808,12 @@ class DreamBoothDataset(Dataset):
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))]
cond_images = None
cond_image_column = args.cond_image_column
if cond_image_column is not None:
cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
assert len(instance_images) == len(cond_images)
if args.caption_column is None:
logger.info(
@@ -811,14 +841,23 @@ class DreamBoothDataset(Dataset):
self.custom_instance_prompts = None
self.instance_images = []
for img in instance_images:
self.cond_images = []
for i, img in enumerate(instance_images):
self.instance_images.extend(itertools.repeat(img, repeats))
if args.dataset_name is not None and cond_images is not None:
self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
self.pixel_values = []
for image in self.instance_images:
self.cond_pixel_values = []
for i, image in enumerate(self.instance_images):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
dest_image = None
if self.cond_images:
dest_image = exif_transpose(self.cond_images[i])
if not dest_image.mode == "RGB":
dest_image = dest_image.convert("RGB")
width, height = image.size
@@ -828,25 +867,16 @@ class DreamBoothDataset(Dataset):
self.size = (target_height, target_width)
# based on the bucket assignment, define the transformations
train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
image, dest_image = self.paired_transform(
image,
dest_image=dest_image,
size=self.size,
center_crop=args.center_crop,
random_flip=args.random_flip,
)
image = train_resize(image)
if args.center_crop:
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, self.size)
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
image = train_flip(image)
image = train_transforms(image)
self.pixel_values.append((image, bucket_idx))
if dest_image is not None:
self.cond_pixel_values.append((dest_image, bucket_idx))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
@@ -880,6 +910,9 @@ class DreamBoothDataset(Dataset):
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx
if self.cond_pixel_values:
dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
example["cond_images"] = dest_image
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -902,6 +935,43 @@ class DreamBoothDataset(Dataset):
return example
def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
# 1. Resize (deterministic)
resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
image = resize(image)
if dest_image is not None:
dest_image = resize(dest_image)
# 2. Crop: either center or SAME random crop
if center_crop:
crop = transforms.CenterCrop(size)
image = crop(image)
if dest_image is not None:
dest_image = crop(dest_image)
else:
# get_params returns (i, j, h, w)
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
image = TF.crop(image, i, j, h, w)
if dest_image is not None:
dest_image = TF.crop(dest_image, i, j, h, w)
# 3. Random horizontal flip with the SAME coin flip
if random_flip:
do_flip = random.random() < 0.5
if do_flip:
image = TF.hflip(image)
if dest_image is not None:
dest_image = TF.hflip(dest_image)
# 4. ToTensor + Normalize (deterministic)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
image = normalize(to_tensor(image))
if dest_image is not None:
dest_image = normalize(to_tensor(dest_image))
return (image, dest_image) if dest_image is not None else (image, None)
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
@@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values, "prompts": prompts}
if any("cond_images" in example for example in examples):
cond_pixel_values = [example["cond_images"] for example in examples]
cond_pixel_values = torch.stack(cond_pixel_values)
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
batch.update({"cond_pixel_values": cond_pixel_values})
return batch
@@ -1318,6 +1393,7 @@ def main(args):
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
"proj_mlp",
]
# now we will add new LoRA weights the transformer layers
@@ -1534,7 +1610,10 @@ def main(args):
buckets=buckets,
repeats=args.repeats,
center_crop=args.center_crop,
args=args,
)
if args.cond_image_column is not None:
logger.info("I2I fine-tuning enabled.")
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
@@ -1574,6 +1653,7 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
text_encoder_one.cpu(), text_encoder_two.cpu()
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
@@ -1605,19 +1685,41 @@ def main(args):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
cached_text_embeddings = []
for batch in tqdm(train_dataloader, desc="Embedding prompts"):
batch_prompts = batch["prompts"]
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
batch_prompts, text_encoders, tokenizers
)
cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
if args.validation_prompt is None:
text_encoder_one.cpu(), text_encoder_two.cpu()
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
has_image_input = args.cond_image_column is not None
if args.cache_latents:
latents_cache = []
cond_latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if has_image_input:
batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
if args.validation_prompt is None:
vae.cpu()
del vae
free_memory()
@@ -1678,7 +1780,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_name = "dreambooth-flux-dev-lora"
tracker_name = "dreambooth-flux-kontext-lora"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
@@ -1742,6 +1844,7 @@ def main(args):
sigma = sigma.unsqueeze(-1)
return sigma
has_guidance = unwrap_model(transformer).config.guidance_embeds
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
@@ -1759,9 +1862,7 @@ def main(args):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
)
prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
tokens_two = tokenize_prompt(
@@ -1794,16 +1895,29 @@ def main(args):
if args.cache_latents:
if args.vae_encode_mode == "sample":
model_input = latents_cache[step].sample()
if has_image_input:
cond_model_input = cond_latents_cache[step].sample()
else:
model_input = latents_cache[step].mode()
if has_image_input:
cond_model_input = cond_latents_cache[step].mode()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
if has_image_input:
cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
if args.vae_encode_mode == "sample":
model_input = vae.encode(pixel_values).latent_dist.sample()
if has_image_input:
cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
else:
model_input = vae.encode(pixel_values).latent_dist.mode()
if has_image_input:
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
if has_image_input:
cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor
cond_model_input = cond_model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
@@ -1814,6 +1928,17 @@ def main(args):
accelerator.device,
weight_dtype,
)
if has_image_input:
cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(
cond_model_input.shape[0],
cond_model_input.shape[2] // 2,
cond_model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
cond_latents_ids[..., 0] = 1
latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
@@ -1834,7 +1959,6 @@ def main(args):
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
packed_noisy_model_input = FluxKontextPipeline._pack_latents(
noisy_model_input,
batch_size=model_input.shape[0],
@@ -1842,13 +1966,22 @@ def main(args):
height=model_input.shape[2],
width=model_input.shape[3],
)
orig_inp_shape = packed_noisy_model_input.shape
if has_image_input:
packed_cond_input = FluxKontextPipeline._pack_latents(
cond_model_input,
batch_size=cond_model_input.shape[0],
num_channels_latents=cond_model_input.shape[1],
height=cond_model_input.shape[2],
width=cond_model_input.shape[3],
)
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
# handle guidance
if unwrap_model(transformer).config.guidance_embeds:
# Kontext always has guidance
guidance = None
if has_guidance:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
guidance = None
# Predict the noise residual
model_pred = transformer(
@@ -1862,6 +1995,8 @@ def main(args):
img_ids=latent_image_ids,
return_dict=False,
)[0]
if has_image_input:
model_pred = model_pred[:, : orig_inp_shape[1]]
model_pred = FluxKontextPipeline._unpack_latents(
model_pred,
height=model_input.shape[2] * vae_scale_factor,
@@ -1970,6 +2105,8 @@ def main(args):
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
if has_image_input and args.validation_image:
pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,
@@ -2030,6 +2167,8 @@ def main(args):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
if has_image_input and args.validation_image:
pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,