From 29a930a142bba88ce30eba8e93e512a3e9bdc49c Mon Sep 17 00:00:00 2001 From: Leo Jiang Date: Mon, 12 Jan 2026 07:37:02 -0700 Subject: [PATCH] Bugfix for flux2 img2img2 prediction (#12855) * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 Co-authored-by: tcaimm <93749364+tcaimm@users.noreply.github.com> --------- Co-authored-by: tcaimm <93749364+tcaimm@users.noreply.github.com> --- .../train_dreambooth_lora_flux2_img2img.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 419821e8a8..5af0906642 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1695,9 +1695,13 @@ def main(args): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to( + cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] + cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( device=cond_model_input.device ) + cond_model_input_ids = cond_model_input_ids.view( + cond_model_input.shape[0], -1, model_input_ids.shape[-1] + ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1724,6 +1728,9 @@ def main(args): packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input) packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input) + orig_input_shape = packed_noisy_model_input.shape + orig_input_ids_shape = model_input_ids.shape + # concatenate the model inputs with the cond inputs packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) @@ -1742,7 +1749,8 @@ def main(args): img_ids=model_input_ids, # B, image_seq_len, 4 return_dict=False, )[0] - model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + model_pred = model_pred[:, : orig_input_shape[1], :] + model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :] model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)