mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
minor fix in controlnet flax example (#2986)
* fix the error when push_to_hub but not log validation * contronet_from_pt & controlnet_revision * add intermediate checkpointing to the guide
This commit is contained in:
@@ -320,6 +320,12 @@ Then cd in the example folder and run
|
||||
pip install -U -r requirements_flax.txt
|
||||
```
|
||||
|
||||
If you want to use Weights and Biases logging, you should also install `wandb` now
|
||||
|
||||
```bash
|
||||
pip install wandb
|
||||
```
|
||||
|
||||
Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress
|
||||
|
||||
```
|
||||
@@ -389,4 +395,17 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
|
||||
|
||||
* [Webdataset](https://webdataset.github.io/webdataset/)
|
||||
* [TorchData](https://github.com/pytorch/data)
|
||||
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
|
||||
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
|
||||
|
||||
When work with a larger dataset, you may need to run training process for a long time and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
|
||||
|
||||
```bash
|
||||
--checkpointing_steps=500
|
||||
```
|
||||
This will save the trained model in subfolders of your output_dir. Subfolder names is the number of steps performed so far; for example: a checkpoint saved after 500 training steps would be saved in a subfolder named 500
|
||||
|
||||
You can then start your training from this saved checkpoint with
|
||||
|
||||
```bash
|
||||
--controlnet_model_name_or_path="./control_out/500"
|
||||
```
|
||||
@@ -154,15 +154,16 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
|
||||
|
||||
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
|
||||
img_str = ""
|
||||
for i, log in enumerate(image_logs):
|
||||
images = log["images"]
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
validation_image.save(os.path.join(repo_folder, "image_control.png"))
|
||||
img_str += f"prompt: {validation_prompt}\n"
|
||||
images = [validation_image] + images
|
||||
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
|
||||
img_str += f"\n"
|
||||
if image_logs is not None:
|
||||
for i, log in enumerate(image_logs):
|
||||
images = log["images"]
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
validation_image.save(os.path.join(repo_folder, "image_control.png"))
|
||||
img_str += f"prompt: {validation_prompt}\n"
|
||||
images = [validation_image] + images
|
||||
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
@@ -213,6 +214,17 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Load the pretrained model from a PyTorch checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet_revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Revision of controlnet model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet_from_pt",
|
||||
action="store_true",
|
||||
help="Load the controlnet model from a PyTorch checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
@@ -731,7 +743,10 @@ def main():
|
||||
if args.controlnet_model_name_or_path:
|
||||
logger.info("Loading existing controlnet weights")
|
||||
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
||||
args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32
|
||||
args.controlnet_model_name_or_path,
|
||||
revision=args.controlnet_revision,
|
||||
from_pt=args.controlnet_from_pt,
|
||||
dtype=jnp.float32,
|
||||
)
|
||||
else:
|
||||
logger.info("Initializing controlnet weights from unet")
|
||||
@@ -1021,6 +1036,8 @@ def main():
|
||||
if jax.process_index() == 0:
|
||||
if args.validation_prompt is not None:
|
||||
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
|
||||
else:
|
||||
image_logs = None
|
||||
|
||||
controlnet.save_pretrained(
|
||||
args.output_dir,
|
||||
|
||||
Reference in New Issue
Block a user