Mirror of https://github.com/huggingface/diffusers.git
[Training] QoL improvements in the Flux Control training scripts (#10461)
* QoL improvements to the Flux script.
* Propagate the dataloader changes.
@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"
 
 gen_images = pipe(
     prompt=prompt,
-    condition_image=image,
+    control_image=image,
     num_inference_steps=50,
     joint_attention_kwargs={"scale": 0.9},
     guidance_scale=25.,
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"
 
 gen_images = pipe(
     prompt=prompt,
-    condition_image=image,
+    control_image=image,
     num_inference_steps=50,
     guidance_scale=25.,
 ).images[0]
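Both README snippets now pass the conditioning image as `control_image` instead of `condition_image`. For context, the call sits in an inference snippet along these lines; this is a minimal sketch, where the pipeline class and checkpoint id are assumptions rather than part of this diff:

```python
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

# Assumed base checkpoint; in practice you would load your fine-tuned
# Flux Control weights on top of (or instead of) it.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("conditioning.png")  # hypothetical conditioning image
prompt = "A couple, 4k photo, highly detailed"

gen_images = pipe(
    prompt=prompt,
    control_image=image,  # renamed from `condition_image`
    num_inference_steps=50,
    guidance_scale=25.0,
).images[0]
gen_images.save("output.png")
```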
@@ -200,5 +200,5 @@ gen_images.save("output.png")
 ## Things to note
 
 * The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
-* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
+* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
 * We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
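On that last bullet: the linked script essentially takes a low-rank approximation of the weight delta between the fine-tuned and base models. Here is a minimal sketch of the idea, not the linked implementation; all names are illustrative:

```python
import torch

def extract_lora_pair(w_base: torch.Tensor, w_tuned: torch.Tensor, rank: int = 64):
    """Approximate (w_tuned - w_base) with rank-`rank` factors via truncated SVD."""
    delta = (w_tuned - w_base).float()
    u, s, vh = torch.linalg.svd(delta, full_matrices=False)
    lora_up = u[:, :rank] * s[:rank]  # (out_features, rank)
    lora_down = vh[:rank, :]          # (rank, in_features)
    return lora_down, lora_up         # w_base + lora_up @ lora_down ≈ w_tuned
```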
@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
 
     for _ in range(args.num_validation_images):
         with autocast_ctx:
-            # need to fix in pipeline_flux_controlnet
             image = pipeline(
                 prompt=validation_prompt,
                 control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
             images = log["images"]
             validation_prompt = log["validation_prompt"]
             validation_image = log["validation_image"]
-            formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+            formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
             for image in images:
                 image = wandb.Image(image, caption=validation_prompt)
                 formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
         img_str += f"\n"
 
     model_description = f"""
-# control-lora-{repo_id}
+# flux-control-{repo_id}
 
 These are Control weights trained on {base_model} with new type of conditioning.
 {img_str}
@@ -434,7 +433,7 @@ def parse_args(input_args=None):
         "--conditioning_image_column",
         type=str,
         default="conditioning_image",
-        help="The column of the dataset containing the controlnet conditioning image.",
+        help="The column of the dataset containing the control conditioning image.",
     )
     parser.add_argument(
         "--caption_column",
@@ -442,6 +441,7 @@ def parse_args(input_args=None):
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+            "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
             " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
         default=None,
         help="Path to the jsonl file containing the training data.",
     )
-
+    parser.add_argument(
+        "--only_target_transformer_blocks",
+        action="store_true",
+        help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
+    )
     parser.add_argument(
         "--guidance_scale",
         type=float,
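The new switch (like `--log_dataset_samples` above) uses `action="store_true"`, so it is a boolean flag that takes no value; passing it simply flips the attribute on. A quick illustration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--only_target_transformer_blocks", action="store_true")

args = parser.parse_args(["--only_target_transformer_blocks"])
print(args.only_target_transformer_blocks)  # True (False when the flag is omitted)
```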
@@ -581,7 +585,7 @@ def parse_args(input_args=None):
 
     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )
 
     return args
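The reworded error names the right component (there is no separate ControlNet encoder in this setup); the constraint itself comes from the VAE's 8x spatial downsampling:

```python
resolution = 512
assert resolution % 8 == 0, "the VAE downsamples by a factor of 8"
latent_size = resolution // 8  # 512x512 pixels -> 64x64 latents
```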
@@ -665,7 +669,12 @@ def prepare_train_dataset(dataset, accelerator):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])
 
         return examples
 
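This is the dataloader change mentioned in the commit message: when the caption column holds a list of captions per image, the longest one is kept; plain string captions pass through unchanged. For example:

```python
captions_batch = [
    ["a photo", "a 4k photo of a couple, highly detailed"],
    ["city street", "a rainy city street at night, neon reflections"],
]
# One caption per example: the longest string in each list.
longest = [max(example, key=len) for example in captions_batch]
# -> ["a 4k photo of a couple, highly detailed",
#     "a rainy city street at night, neon reflections"]
```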
@@ -765,7 +774,8 @@ def main(args):
         subfolder="scheduler",
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-    flux_transformer.requires_grad_(True)
+    if not args.only_target_transformer_blocks:
+        flux_transformer.requires_grad_(True)
     vae.requires_grad_(False)
 
     # cast down and move to the CPU
@@ -797,6 +807,12 @@ def main(args):
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
+    if args.only_target_transformer_blocks:
+        flux_transformer.x_embedder.requires_grad_(True)
+        for name, module in flux_transformer.named_modules():
+            if "transformer_blocks" in name:
+                module.requires_grad_(True)
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
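Together with the `if not args.only_target_transformer_blocks:` guard in the previous hunk, the flag restricts training to `x_embedder` plus every module whose name contains "transformer_blocks" (which matches both the double-stream `transformer_blocks` and the `single_transformer_blocks` of the Flux transformer). A condensed sketch of the pattern, not the script verbatim; the freeze is made explicit here:

```python
def configure_trainable_params(flux_transformer, only_target_transformer_blocks: bool):
    if only_target_transformer_blocks:
        flux_transformer.requires_grad_(False)            # freeze everything first...
        flux_transformer.x_embedder.requires_grad_(True)  # ...except the input layer
        for name, module in flux_transformer.named_modules():
            if "transformer_blocks" in name:              # ...and the transformer blocks
                module.requires_grad_(True)
    else:
        flux_transformer.requires_grad_(True)             # full fine-tuning
```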
@@ -974,6 +990,32 @@ def main(args):
     else:
         initial_global_step = 0
 
+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
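This sample-logging block is added to both training scripts (the matching hunk for the LoRA script appears further down). The `(x + 1) / 2` rescales pixel tensors from the dataloader's [-1, 1] normalization back to [0, 1] before handing them to `wandb.Image`:

```python
import torch

pixel_values = torch.tensor([-1.0, 0.0, 1.0])  # normalized as in the dataloader
images = (pixel_values + 1) / 2
print(images)  # tensor([0.0000, 0.5000, 1.0000])
```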
@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
 
     for _ in range(args.num_validation_images):
         with autocast_ctx:
-            # need to fix in pipeline_flux_controlnet
             image = pipeline(
                 prompt=validation_prompt,
                 control_image=validation_image,
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
             images = log["images"]
             validation_prompt = log["validation_prompt"]
             validation_image = log["validation_image"]
-            formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+            formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
             for image in images:
                 image = wandb.Image(image, caption=validation_prompt)
                 formatted_images.append(image)
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
         img_str += f"\n"
 
     model_description = f"""
-# controlnet-lora-{repo_id}
+# control-lora-{repo_id}
 
 These are Control LoRA weights trained on {base_model} with new type of conditioning.
 {img_str}
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="controlnet-lora",
+        default="control-lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
@@ -466,7 +465,7 @@ def parse_args(input_args=None):
         "--conditioning_image_column",
         type=str,
         default="conditioning_image",
-        help="The column of the dataset containing the controlnet conditioning image.",
+        help="The column of the dataset containing the control conditioning image.",
     )
     parser.add_argument(
         "--caption_column",
@@ -474,6 +473,7 @@ def parse_args(input_args=None):
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+            "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
             " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@ def parse_args(input_args=None):
 
     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
         )
 
     return args
@@ -697,7 +697,12 @@ def prepare_train_dataset(dataset, accelerator):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])
 
         return examples
 
@@ -1132,6 +1137,32 @@ def main(args):
     else:
         initial_global_step = 0
 
+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,