Bugfix for dreambooth flux2 img2img (#12825)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
@@ -346,7 +346,7 @@ def parse_args(input_args=None):
         "--instance_prompt",
         type=str,
         default=None,
-        required=True,
+        required=False,
         help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
     )
     parser.add_argument(
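With this change --instance_prompt can be omitted, so args.instance_prompt may be None at runtime and downstream code has to tolerate that. A standalone argparse sketch of the same pattern (the parser below is illustrative only, not the script's full parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--instance_prompt",
    type=str,
    default=None,
    required=False,  # was required=True before this fix
    help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog'",
)

args = parser.parse_args([])  # flag omitted on the command line
print(args.instance_prompt)   # None -> callers must handle the missing-prompt case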
@@ -835,15 +835,28 @@ class DreamBoothDataset(Dataset):
             dest_image = self.cond_images[i]
             image_width, image_height = dest_image.size
             if image_width * image_height > 1024 * 1024:
-                dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
+                dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
                 image_width, image_height = dest_image.size

             multiple_of = 2 ** (4 - 1)  # 2 ** (len(vae.config.block_out_channels) - 1), temp!
             image_width = (image_width // multiple_of) * multiple_of
             image_height = (image_height // multiple_of) * multiple_of
-            dest_image = Flux2ImageProcessor.image_processor.preprocess(
+            image_processor = Flux2ImageProcessor()
+            dest_image = image_processor.preprocess(
                 dest_image, height=image_height, width=image_width, resize_mode="crop"
             )
+            # Convert back to PIL
+            dest_image = dest_image.squeeze(0)
+            if dest_image.min() < 0:
+                dest_image = (dest_image + 1) / 2
+            dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
+
+            if dest_image.shape[0] == 1:
+                # Gray scale image
+                dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
+            else:
+                # RGB scale image: (C, H, W) -> (H, W, C)
+                dest_image = TF.to_pil_image(dest_image)
+
             dest_image = exif_transpose(dest_image)
             if not dest_image.mode == "RGB":
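The lines added above round-trip the preprocessed tensor back to a PIL image so that the existing exif_transpose / RGB-conversion path below still applies. A minimal sketch of that conversion on a dummy tensor in [-1, 1]; the random input and the 512x512 size are placeholders, not values from the script:

import torch
import torchvision.transforms.functional as TF
from PIL import Image

# stand-in for the (1, C, H, W) output of the image processor's preprocess()
dest_image = torch.rand(1, 3, 512, 512) * 2 - 1

dest_image = dest_image.squeeze(0)        # (1, C, H, W) -> (C, H, W)
if dest_image.min() < 0:                  # map [-1, 1] back to [0, 1]
    dest_image = (dest_image + 1) / 2
dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()

if dest_image.shape[0] == 1:
    # single channel -> grayscale PIL image
    pil_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
else:
    # uint8 (C, H, W) tensor -> RGB PIL image
    pil_image = TF.to_pil_image(dest_image)

print(pil_image.size, pil_image.mode)  # (512, 512) RGB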
@@ -1463,9 +1476,9 @@ def main(args):
         args.instance_prompt, text_encoding_pipeline
     )

+    validation_image = load_image(args.validation_image_path).convert("RGB")
+    validation_kwargs = {"image": validation_image}
     if args.validation_prompt is not None:
-        validation_image = load_image(args.validation_image_path).convert("RGB")
-        validation_kwargs = {"image": validation_image}
         if args.remote_text_encoder:
             validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
         else:
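The last hunk hoists the validation-image loading out of the validation-prompt check, so the img2img conditioning image is loaded and validation_kwargs is defined even when no validation prompt is given. A simplified, hypothetical sketch of the resulting flow; build_validation_kwargs, the stubbed load_image, and the SimpleNamespace args are illustrative, not code from the script:

from types import SimpleNamespace
from PIL import Image

def load_image(path):
    # stand-in for diffusers.utils.load_image; returns a blank image instead of reading a file
    return Image.new("RGB", (64, 64))

def build_validation_kwargs(args):
    validation_image = load_image(args.validation_image_path).convert("RGB")
    validation_kwargs = {"image": validation_image}  # always present for img2img validation
    if args.validation_prompt is not None:
        validation_kwargs["prompt"] = args.validation_prompt
    return validation_kwargs

args = SimpleNamespace(validation_image_path="cond.png", validation_prompt=None)
print(build_validation_kwargs(args).keys())  # dict_keys(['image']) even without a prompt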