[training] add Kontext i2i training (#11858)

* feat: enable i2i fine-tuning in Kontext script.
* readme
* more checks.
* Apply suggestions from code review
* fixes
* fix
* add proj_mlp to the mix
* Update README_flux.md: add note on installing from commit `05e7a854d0a5661f5b433f6dd5954c224b104f0b`
* fix
* fix

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
@@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N
 ## Training Kontext
 
 [Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
-provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply to this script, too.
+provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.
 
 Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
 
 **Important**
+
+> [!NOTE]
+> To make sure you can successfully run the latest version of the Kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
+> To do this, execute the following steps in a new virtual environment:
+> ```
+> git clone https://github.com/huggingface/diffusers
+> cd diffusers
+> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
+> pip install -e .
+> ```
 
 Below is an example training command:
@@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
 Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
 perform as expected.
 
+Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets of:
+
+* Condition image
+* Target image
+* Instruction
+
+[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using a dataset like this, you can use the command below to launch training:
+
+```bash
+accelerate launch train_dreambooth_lora_flux_kontext.py \
+--pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
+--output_dir="kontext-i2i" \
+--dataset_name="kontext-community/relighting" \
+--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
+--mixed_precision="bf16" \
+--resolution=1024 \
+--train_batch_size=1 \
+--guidance_scale=1 \
+--gradient_accumulation_steps=4 \
+--gradient_checkpointing \
+--optimizer="adamw" \
+--use_8bit_adam \
+--cache_latents \
+--learning_rate=1e-4 \
+--lr_scheduler="constant" \
+--lr_warmup_steps=200 \
+--max_train_steps=1000 \
+--rank=16 \
+--seed="0"
+```
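After training finishes, the resulting LoRA can be loaded on top of the base Kontext checkpoint for image-guided inference. The snippet below is a minimal sketch, not part of the script itself; it assumes the `kontext-i2i` output directory from the command above and uses a placeholder condition-image URL that you should replace with your own.

```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Load the base Kontext checkpoint and attach the freshly trained LoRA.
pipeline = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_lora_weights("kontext-i2i")  # the directory passed as --output_dir during training

# I2I inference: a condition image plus an edit instruction.
condition = load_image("https://example.com/my_condition_image.png")  # hypothetical URL, replace with yours
image = pipeline(image=condition, prompt="change the lighting to golden hour", guidance_scale=2.5).images[0]
image.save("relit.png")
```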
+
+More generally, when performing I2I fine-tuning, we expect you to:
+
+* Have a dataset structured like `kontext-community/relighting`, i.e., with condition image, target image, and instruction columns
+* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training (a quick way to verify them is sketched below)
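Before launching a run, it can help to confirm that the three columns line up with what you pass on the command line. A quick check with the `datasets` library, assuming the column names used in the example command above, might look like:

```python
from datasets import load_dataset

ds = load_dataset("kontext-community/relighting", split="train")

# The three column names passed to the training script via
# --cond_image_column, --image_column, and --caption_column.
for column in ("file_name", "output", "instruction"):
    assert column in ds.column_names, f"missing column: {column}"

sample = ds[0]
print(type(sample["file_name"]), type(sample["output"]))  # condition / target images
print(sample["instruction"])                              # the edit instruction
```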
+
 ### Misc notes
 
 * By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses the `mode()` of the distribution predicted by the VAE instead of sampling from it (a short sketch of the difference follows below).
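In other words, instead of drawing a random latent from the VAE posterior, the script takes its mode, which for a diagonal Gaussian equals its mean. A minimal sketch of the difference, assuming the Kontext VAE (any `AutoencoderKL` exposes the same `latent_dist` API) and a stand-in image tensor:

```python
import torch
from diffusers import AutoencoderKL

# Assumes access to the Kontext checkpoint; any AutoencoderKL works the same way.
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae")

pixel_values = torch.randn(1, 3, 512, 512)  # stand-in for a normalized training image
with torch.no_grad():
    posterior = vae.encode(pixel_values).latent_dist

stochastic_latents = posterior.sample()   # --vae_encode_mode="sample"
deterministic_latents = posterior.mode()  # --vae_encode_mode="mode" (default, used by Kontext)

# Training then shifts and scales the latents exactly as in the script:
latents = (deterministic_latents - vae.config.shift_factor) * vae.config.scaling_factor
```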
@@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a
 Since Flux Kontext finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗
 
 ## Other notes
 Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
@@ -40,7 +40,7 @@ from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
 from torch.utils.data.sampler import BatchSampler
 from torchvision import transforms
-from torchvision.transforms.functional import crop
+from torchvision.transforms import functional as TF
 from tqdm.auto import tqdm
 from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
 
@@ -62,11 +62,7 @@ from diffusers.training_utils import (
     free_memory,
     parse_buckets_string,
 )
-from diffusers.utils import (
-    check_min_version,
-    convert_unet_state_dict_to_peft,
-    is_wandb_available,
-)
+from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -186,6 +182,7 @@ def log_validation(
     )
     pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
+    pipeline_args_cp = pipeline_args.copy()
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
@@ -193,14 +190,16 @@ def log_validation(
 
     # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
     with torch.no_grad():
-        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
-            pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
-        )
+        prompt = pipeline_args_cp.pop("prompt")
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
     images = []
     for _ in range(args.num_validation_images):
         with autocast_ctx:
             image = pipeline(
-                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
+                **pipeline_args_cp,
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                generator=generator,
             ).images[0]
             images.append(image)
 
@@ -310,6 +309,12 @@ def parse_args(input_args=None):
         "default, the standard Image Dataset maps out 'file_name' "
         "to 'image'.",
     )
+    parser.add_argument(
+        "--cond_image_column",
+        type=str,
+        default=None,
+        help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
+    )
     parser.add_argument(
         "--caption_column",
         type=str,
@@ -330,7 +335,6 @@ def parse_args(input_args=None):
         "--instance_prompt",
         type=str,
         default=None,
-        required=True,
         help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
     )
     parser.add_argument(
@@ -351,6 +355,12 @@ def parse_args(input_args=None):
         default=None,
         help="A prompt that is used during validation to verify that the model is learning.",
     )
+    parser.add_argument(
+        "--validation_image",
+        type=str,
+        default=None,
+        help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.",
+    )
     parser.add_argument(
         "--num_validation_images",
         type=int,
@@ -399,7 +409,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="flux-dreambooth-lora",
+        default="flux-kontext-lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -716,6 +726,8 @@ def parse_args(input_args=None):
             raise ValueError("You must specify a data directory for class images.")
         if args.class_prompt is None:
             raise ValueError("You must specify prompt for class images.")
+        if args.cond_image_column is not None:
+            raise ValueError("Prior preservation isn't supported with I2I training.")
     else:
         # logger is not available yet
         if args.class_data_dir is not None:
@@ -723,6 +735,14 @@ def parse_args(input_args=None):
         if args.class_prompt is not None:
             warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
 
+    if args.cond_image_column is not None:
+        assert args.image_column is not None
+        assert args.caption_column is not None
+        assert args.dataset_name is not None
+        assert not args.train_text_encoder
+        if args.validation_prompt is not None:
+            assert args.validation_image is not None and os.path.exists(args.validation_image)
+
     return args
 
 
@@ -742,6 +762,7 @@ class DreamBoothDataset(Dataset):
         repeats=1,
         center_crop=False,
         buckets=None,
+        args=None,
     ):
         self.center_crop = center_crop
 
@@ -774,6 +795,10 @@ class DreamBoothDataset(Dataset):
             column_names = dataset["train"].column_names
 
             # 6. Get the column names for input/target.
+            if args.cond_image_column is not None and args.cond_image_column not in column_names:
+                raise ValueError(
+                    f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                )
             if args.image_column is None:
                 image_column = column_names[0]
                 logger.info(f"image column defaulting to {image_column}")
@@ -783,7 +808,12 @@ class DreamBoothDataset(Dataset):
                     raise ValueError(
                         f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
-            instance_images = dataset["train"][image_column]
+            instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))]
+            cond_images = None
+            cond_image_column = args.cond_image_column
+            if cond_image_column is not None:
+                cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
+                assert len(instance_images) == len(cond_images)
 
             if args.caption_column is None:
                 logger.info(
|
             self.custom_instance_prompts = None
 
         self.instance_images = []
-        for img in instance_images:
+        self.cond_images = []
+        for i, img in enumerate(instance_images):
             self.instance_images.extend(itertools.repeat(img, repeats))
+            if args.dataset_name is not None and cond_images is not None:
+                self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
 
         self.pixel_values = []
-        for image in self.instance_images:
+        self.cond_pixel_values = []
+        for i, image in enumerate(self.instance_images):
             image = exif_transpose(image)
             if not image.mode == "RGB":
                 image = image.convert("RGB")
+            dest_image = None
+            if self.cond_images:
+                dest_image = exif_transpose(self.cond_images[i])
+                if not dest_image.mode == "RGB":
+                    dest_image = dest_image.convert("RGB")
 
             width, height = image.size
 
@@ -828,25 +867,16 @@ class DreamBoothDataset(Dataset):
             self.size = (target_height, target_width)
 
             # based on the bucket assignment, define the transformations
-            train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
-            train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
-            train_flip = transforms.RandomHorizontalFlip(p=1.0)
-            train_transforms = transforms.Compose(
-                [
-                    transforms.ToTensor(),
-                    transforms.Normalize([0.5], [0.5]),
-                ]
+            image, dest_image = self.paired_transform(
+                image,
+                dest_image=dest_image,
+                size=self.size,
+                center_crop=args.center_crop,
+                random_flip=args.random_flip,
             )
-            image = train_resize(image)
-            if args.center_crop:
-                image = train_crop(image)
-            else:
-                y1, x1, h, w = train_crop.get_params(image, self.size)
-                image = crop(image, y1, x1, h, w)
-            if args.random_flip and random.random() < 0.5:
-                image = train_flip(image)
-            image = train_transforms(image)
             self.pixel_values.append((image, bucket_idx))
+            if dest_image is not None:
+                self.cond_pixel_values.append((dest_image, bucket_idx))
 
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
@@ -880,6 +910,9 @@ class DreamBoothDataset(Dataset):
         instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
         example["instance_images"] = instance_image
         example["bucket_idx"] = bucket_idx
+        if self.cond_pixel_values:
+            dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
+            example["cond_images"] = dest_image
 
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -902,6 +935,43 @@ class DreamBoothDataset(Dataset):
 
         return example
 
+    def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
+        # 1. Resize (deterministic)
+        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        image = resize(image)
+        if dest_image is not None:
+            dest_image = resize(dest_image)
+
+        # 2. Crop: either center or SAME random crop
+        if center_crop:
+            crop = transforms.CenterCrop(size)
+            image = crop(image)
+            if dest_image is not None:
+                dest_image = crop(dest_image)
+        else:
+            # get_params returns (i, j, h, w)
+            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
+            image = TF.crop(image, i, j, h, w)
+            if dest_image is not None:
+                dest_image = TF.crop(dest_image, i, j, h, w)
+
+        # 3. Random horizontal flip with the SAME coin flip
+        if random_flip:
+            do_flip = random.random() < 0.5
+            if do_flip:
+                image = TF.hflip(image)
+                if dest_image is not None:
+                    dest_image = TF.hflip(dest_image)
+
+        # 4. ToTensor + Normalize (deterministic)
+        to_tensor = transforms.ToTensor()
+        normalize = transforms.Normalize([0.5], [0.5])
+        image = normalize(to_tensor(image))
+        if dest_image is not None:
+            dest_image = normalize(to_tensor(dest_image))
+
+        return (image, dest_image) if dest_image is not None else (image, None)
+
 
 def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
@@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
     batch = {"pixel_values": pixel_values, "prompts": prompts}
+    if any("cond_images" in example for example in examples):
+        cond_pixel_values = [example["cond_images"] for example in examples]
+        cond_pixel_values = torch.stack(cond_pixel_values)
+        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
+        batch.update({"cond_pixel_values": cond_pixel_values})
     return batch
 
 
@@ -1318,6 +1393,7 @@ def main(args):
         "ff.net.2",
         "ff_context.net.0.proj",
         "ff_context.net.2",
+        "proj_mlp",
     ]
 
     # now we will add new LoRA weights to the transformer layers
@@ -1534,7 +1610,10 @@ def main(args):
         buckets=buckets,
         repeats=args.repeats,
         center_crop=args.center_crop,
+        args=args,
     )
+    if args.cond_image_column is not None:
+        logger.info("I2I fine-tuning enabled.")
     batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
@@ -1574,6 +1653,7 @@ def main(args):
 
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+        text_encoder_one.cpu(), text_encoder_two.cpu()
         del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
         free_memory()
 
@@ -1605,19 +1685,41 @@ def main(args):
                 tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
                 tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
+    elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
+        cached_text_embeddings = []
+        for batch in tqdm(train_dataloader, desc="Embedding prompts"):
+            batch_prompts = batch["prompts"]
+            prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+                batch_prompts, text_encoders, tokenizers
+            )
+            cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
+
+        if args.validation_prompt is None:
+            text_encoder_one.cpu(), text_encoder_two.cpu()
+            del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+            free_memory()
+
     vae_config_shift_factor = vae.config.shift_factor
     vae_config_scaling_factor = vae.config.scaling_factor
     vae_config_block_out_channels = vae.config.block_out_channels
+    has_image_input = args.cond_image_column is not None
     if args.cache_latents:
         latents_cache = []
+        cond_latents_cache = []
         for batch in tqdm(train_dataloader, desc="Caching latents"):
             with torch.no_grad():
                 batch["pixel_values"] = batch["pixel_values"].to(
                     accelerator.device, non_blocking=True, dtype=weight_dtype
                 )
                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                if has_image_input:
+                    batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
+                        accelerator.device, non_blocking=True, dtype=weight_dtype
+                    )
+                    cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
 
         if args.validation_prompt is None:
             vae.cpu()
             del vae
             free_memory()
 
@@ -1678,7 +1780,7 @@ def main(args):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initialize automatically on the main process.
     if accelerator.is_main_process:
-        tracker_name = "dreambooth-flux-dev-lora"
+        tracker_name = "dreambooth-flux-kontext-lora"
         accelerator.init_trackers(tracker_name, config=vars(args))
 
     # Train!
@@ -1742,6 +1844,7 @@ def main(args):
         sigma = sigma.unsqueeze(-1)
         return sigma
 
+    has_guidance = unwrap_model(transformer).config.guidance_embeds
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
         if args.train_text_encoder:
@@ -1759,9 +1862,7 @@ def main(args):
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     if not args.train_text_encoder:
-                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
-                            prompts, text_encoders, tokenizers
-                        )
+                        prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
                     else:
                         tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
                         tokens_two = tokenize_prompt(
@@ -1794,16 +1895,29 @@ def main(args):
                 if args.cache_latents:
                     if args.vae_encode_mode == "sample":
                         model_input = latents_cache[step].sample()
+                        if has_image_input:
+                            cond_model_input = cond_latents_cache[step].sample()
                     else:
                         model_input = latents_cache[step].mode()
+                        if has_image_input:
+                            cond_model_input = cond_latents_cache[step].mode()
                 else:
                     pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    if has_image_input:
+                        cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
                     if args.vae_encode_mode == "sample":
                         model_input = vae.encode(pixel_values).latent_dist.sample()
+                        if has_image_input:
+                            cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
                     else:
                         model_input = vae.encode(pixel_values).latent_dist.mode()
+                        if has_image_input:
+                            cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
+                if has_image_input:
+                    cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor
+                    cond_model_input = cond_model_input.to(dtype=weight_dtype)
 
                 vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
 
@@ -1814,6 +1928,17 @@ def main(args):
                     accelerator.device,
                     weight_dtype,
                 )
+                if has_image_input:
+                    cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(
+                        cond_model_input.shape[0],
+                        cond_model_input.shape[2] // 2,
+                        cond_model_input.shape[3] // 2,
+                        accelerator.device,
+                        weight_dtype,
+                    )
+                    cond_latents_ids[..., 0] = 1
+                    latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
                 bsz = model_input.shape[0]
@@ -1834,7 +1959,6 @@ def main(args):
                 # zt = (1 - texp) * x + texp * z1
                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
                 noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
                 packed_noisy_model_input = FluxKontextPipeline._pack_latents(
                     noisy_model_input,
                     batch_size=model_input.shape[0],
@@ -1842,13 +1966,22 @@ def main(args):
                     height=model_input.shape[2],
                     width=model_input.shape[3],
                 )
+                orig_inp_shape = packed_noisy_model_input.shape
+                if has_image_input:
+                    packed_cond_input = FluxKontextPipeline._pack_latents(
+                        cond_model_input,
+                        batch_size=cond_model_input.shape[0],
+                        num_channels_latents=cond_model_input.shape[1],
+                        height=cond_model_input.shape[2],
+                        width=cond_model_input.shape[3],
+                    )
+                    packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
 
                 # handle guidance
-                if unwrap_model(transformer).config.guidance_embeds:
+                # Kontext always has guidance
+                guidance = None
+                if has_guidance:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
-                else:
-                    guidance = None
 
                 # Predict the noise residual
                 model_pred = transformer(
@@ -1862,6 +1995,8 @@ def main(args):
                     img_ids=latent_image_ids,
                     return_dict=False,
                 )[0]
+                if has_image_input:
+                    model_pred = model_pred[:, : orig_inp_shape[1]]
                 model_pred = FluxKontextPipeline._unpack_latents(
                     model_pred,
                     height=model_input.shape[2] * vae_scale_factor,
@@ -1970,6 +2105,8 @@ def main(args):
                     torch_dtype=weight_dtype,
                 )
                 pipeline_args = {"prompt": args.validation_prompt}
+                if has_image_input and args.validation_image:
+                    pipeline_args.update({"image": load_image(args.validation_image)})
                 images = log_validation(
                     pipeline=pipeline,
                     args=args,
@@ -2030,6 +2167,8 @@ def main(args):
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
             pipeline_args = {"prompt": args.validation_prompt}
+            if has_image_input and args.validation_image:
+                pipeline_args.update({"image": load_image(args.validation_image)})
             images = log_validation(
                 pipeline=pipeline,
                 args=args,
 