1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge branch 'main' into remove-explicit-typing

This commit is contained in:
Sayak Paul
2026-01-12 21:36:43 +05:30
committed by GitHub

View File

@@ -1695,9 +1695,13 @@ def main(args):
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
device=cond_model_input.device
)
cond_model_input_ids = cond_model_input_ids.view(
cond_model_input.shape[0], -1, model_input_ids.shape[-1]
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
@@ -1724,6 +1728,9 @@ def main(args):
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
orig_input_shape = packed_noisy_model_input.shape
orig_input_ids_shape = model_input_ids.shape
# concatenate the model inputs with the cond inputs
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
@@ -1742,7 +1749,8 @@ def main(args):
img_ids=model_input_ids, # B, image_seq_len, 4
return_dict=False,
)[0]
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
model_pred = model_pred[:, : orig_input_shape[1], :]
model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)