[LoRA] quality of life improvements in the loading semantics and docs (#3180)
* 👽 qol improvements for LoRA.
* better function name?
* fix: LoRA weight loading with the new format.
* address Patrick's comments.
* Apply suggestions from code review
* change wording around encouraging the use of load_lora_weights().
* fix: function name.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -171,7 +171,7 @@
     - local: api/pipelines/semantic_stable_diffusion
       title: Semantic Guidance
     - local: api/pipelines/spectrogram_diffusion
-      title: "Spectrogram Diffusion"
+      title: Spectrogram Diffusion
   - sections:
     - local: api/pipelines/stable_diffusion/overview
       title: Overview
@@ -238,6 +238,8 @@
       title: DPM Discrete Scheduler
     - local: api/schedulers/dpm_discrete_ancestral
       title: DPM Discrete Scheduler with ancestral sampling
+    - local: api/schedulers/dpm_sde
+      title: DPMSolverSDEScheduler
     - local: api/schedulers/euler_ancestral
       title: Euler Ancestral Scheduler
     - local: api/schedulers/euler
@@ -266,8 +268,6 @@
       title: VP-SDE
     - local: api/schedulers/vq_diffusion
       title: VQDiffusionScheduler
-    - local: api/schedulers/dpm_sde
-      title: DPMSolverSDEScheduler
   title: Schedulers
 - sections:
   - local: api/experimental/rl

@@ -115,7 +115,7 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
 </Tip>

 ```py
->>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.unet.load_attn_procs(lora_model_path)
 >>> pipe.to("cuda")
 # use half the weights from the LoRA finetuned model and half the weights from the base model

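The comment above about using half the LoRA weights refers to the `scale` entry of `cross_attention_kwargs` described elsewhere in this guide. A minimal sketch of how the surrounding snippet is typically completed, assuming the base model and the pokemon LoRA checkpoint used in these docs:

```py
import torch
from diffusers import StableDiffusionPipeline

# Assumed checkpoints for illustration; substitute your own base model and LoRA path.
lora_model_path = "sayakpaul/sd-model-finetuned-lora-t4"
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.unet.load_attn_procs(lora_model_path)
pipe.to("cuda")

# scale=0.5 blends the LoRA attention weights 50/50 with the base model's weights.
image = pipe(
    "A pokemon with blue eyes.",
    num_inference_steps=25,
    guidance_scale=7.5,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
image.save("blue_pokemon.png")
```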
@@ -128,6 +128,25 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
 >>> image.save("blue_pokemon.png")
 ```

+<Tip>
+
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+</Tip>
+
 ## DreamBooth

 [DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).

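The `...` in the added tip elides the rest of the example. A minimal sketch of the full flow under the same assumptions (the checkpoint's `base_model` tag resolves to a Stable Diffusion base), using `load_lora_weights` as the commit encourages:

```py
import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub.repocard import RepoCard

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

# Build the pipeline from the base model recorded on the LoRA repo card,
# then load the LoRA parameters on top of it.
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe.load_lora_weights(lora_model_id)
pipe.to("cuda")

image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
image.save("blue_pokemon.png")
```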
@@ -208,7 +227,7 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
 </Tip>

 ```py
->>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.unet.load_attn_procs(lora_model_path)
 >>> pipe.to("cuda")
 # use half the weights from the LoRA finetuned model and half the weights from the base model

@@ -222,4 +241,15 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m

 >>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
 >>> image.save("bucket-dog.png")
 ```
+
+Note that the use of [`LoraLoaderMixin.load_lora_weights`] is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because
+[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
+
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
+
+  ```py
+  pipe.load_lora_weights(lora_model_path)
+  ```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).

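A short sketch of the two cases listed in the added bullets; the base checkpoint below is an assumption for illustration (check each LoRA repo's `base_model` tag rather than relying on it):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # assumed base model
).to("cuda")

# Case 1: a checkpoint whose state dict has no `unet.`/`text_encoder.` prefixes (UNet-only LoRA).
pipe.load_lora_weights("patrickvonplaten/lora_dreambooth_dog_example")

# Case 2: a checkpoint that stores both UNet and text encoder LoRA params works with the same call:
# pipe.load_lora_weights("sayakpaul/dreambooth")

image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
```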
@@ -355,7 +355,7 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
 The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
 You can use the `Step` slider to see how the model learned the features of our subject while the model trained.

-Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we
+Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested to know more about how we
 enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).

 With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).

@@ -387,6 +387,33 @@ Finally, we can run the model in inference.
 image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
 ```

+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+**Note** that we will gradually be deprecating the use of [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs) since we now have a more general
+method to load the LoRA parameters -- [`LoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights). This is because
+[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
+
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
+
+  ```py
+  pipe.load_lora_weights(lora_model_path)
+  ```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
+
 ## Training with Flax/JAX

 For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

@@ -1045,7 +1045,7 @@ def main(args):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        pipeline.load_attn_procs(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir)

         # run inference
         if args.validation_prompt and args.num_validation_images > 0:

@@ -281,10 +281,14 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             # save_pretrained smoke test
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

-            # the names of the keys of the state dict should either start with `unet`
-            # or `text_encoder`.
+            # check `text_encoder` is present at all.
             lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
             keys = lora_state_dict.keys()
+            is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+            self.assertTrue(is_text_encoder_present)
+
+            # the names of the keys of the state dict should either start with `unet`
+            # or `text_encoder`.
             is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
             self.assertTrue(is_correct_naming)

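Outside the test suite, the same naming convention can be checked by hand on a saved checkpoint; a minimal sketch with a hypothetical local path:

```py
import torch

# In the new serialization format, every key carries a `unet.` or `text_encoder.` prefix.
state_dict = torch.load("sd-model-finetuned-lora/pytorch_lora_weights.bin", map_location="cpu")

unet_keys = [k for k in state_dict if k.startswith("unet")]
text_encoder_keys = [k for k in state_dict if k.startswith("text_encoder")]
print(f"{len(unet_keys)} UNet keys, {len(text_encoder_keys)} text encoder keys")
assert len(unet_keys) + len(text_encoder_keys) == len(state_dict)
```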
@@ -229,6 +229,21 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
 image.save("pokemon.png")
 ```

+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
 ## Training with Flax/JAX

 For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union

@@ -45,6 +46,8 @@ if is_transformers_available():

 logger = logging.get_logger(__name__)

+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"

 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

@@ -87,6 +90,9 @@ class AttnProcsLayers(torch.nn.Module):


 class UNet2DConditionLoadersMixin:
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME
+
     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
         Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be

@@ -225,6 +231,18 @@ class UNet2DConditionLoadersMixin:
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

         if is_lora:
+            is_new_lora_format = all(
+                key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+            )
+            if is_new_lora_format:
+                # Strip the `"unet"` prefix.
+                is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+                if is_text_encoder_present:
+                    warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+                    warnings.warn(warn_message)
+                unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+                state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
             lora_grouped_dict = defaultdict(dict)
             for key, value in state_dict.items():
                 attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])

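A toy, standalone illustration of the detection and prefix stripping the added branch performs; the key names and placeholder values below are made up for the example and are not the library's exact module paths:

```py
# Placeholder values stand in for the LoRA weight tensors.
state_dict = {
    "unet.down_blocks.0.attentions.0.processor.to_q_lora.down.weight": 0,
    "text_encoder.text_model.encoder.layers.0.to_q_lora.down.weight": 0,
}

# New format: every key starts with the `unet` or `text_encoder` prefix.
is_new_lora_format = all(k.startswith(("unet", "text_encoder")) for k in state_dict)

# `load_attn_procs` only consumes the UNet entries, with the `unet.` prefix stripped.
unet_keys = [k for k in state_dict if k.startswith("unet")]
unet_state_dict = {k.replace("unet.", "", 1): v for k, v in state_dict.items() if k in unet_keys}
print(is_new_lora_format, list(unet_state_dict))
```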
@@ -672,8 +690,8 @@ class LoraLoaderMixin:

     </Tip>
     """
-    text_encoder_name = "text_encoder"
-    unet_name = "unet"
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME

     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""

@@ -810,21 +828,24 @@ class LoraLoaderMixin:
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())

-        # Load the layers corresponding to UNet.
-        if all(key.startswith(self.unet_name) for key in keys):
+        if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
+            # Load the layers corresponding to UNet.
             unet_keys = [k for k in keys if k.startswith(self.unet_name)]
             logger.info(f"Loading {self.unet_name}.")
-            unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)}
+            unet_lora_state_dict = {
+                k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
+            }
             self.unet.load_attn_procs(unet_lora_state_dict)

-        # Load the layers corresponding to text encoder and make necessary adjustments.
-        elif all(key.startswith(self.text_encoder_name) for key in keys):
+            # Load the layers corresponding to text encoder and make necessary adjustments.
             text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
             logger.info(f"Loading {self.text_encoder_name}.")
             text_encoder_lora_state_dict = {
-                k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name)
+                k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
-            attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict)
-            self._modify_text_encoder(attn_procs_text_encoder)
+            if len(text_encoder_lora_state_dict) > 0:
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                self._modify_text_encoder(attn_procs_text_encoder)

         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.

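Since `load_lora_weights` accepts either a repo id/path or an already-loaded state dict, the prefix-based routing above can also be exercised with an in-memory dictionary; a minimal sketch with assumed paths:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # assumed base model
)

# A dict loaded from a new-format checkpoint (keys prefixed with `unet.`/`text_encoder.`)
# is routed to the UNet and the text encoder exactly as in the branch above.
state_dict = torch.load("sd-model-finetuned-lora/pytorch_lora_weights.bin", map_location="cpu")
pipe.load_lora_weights(state_dict)
```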
@@ -832,11 +853,8 @@
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
         ):
             self.unet.load_attn_procs(state_dict)
-            deprecation_message = "You have saved the LoRA weights using the old format. This will be"
-            " deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
-            " in a dictionary and then create a new dictionary like the following:"
-            " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
-            deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False)
+            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
+            warnings.warn(warn_message)

     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""

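The snippet embedded in the warning string reads as pseudocode (`{f'unet'.{module_name}: ...}` is not valid Python as written). The intended conversion is presumably just prefixing every key with `unet.`; a hedged sketch with hypothetical paths:

```py
import torch

# Convert an old-format (un-prefixed, UNet-only) LoRA checkpoint to the new format.
old_state_dict = torch.load("old_lora/pytorch_lora_weights.bin", map_location="cpu")
new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
torch.save(new_state_dict, "old_lora/pytorch_lora_weights_new_format.bin")
```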
@@ -872,7 +890,9 @@
         else:
             return "to_out_lora"

-    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def _load_text_encoder_attn_procs(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
+    ):
         r"""
         Load pretrained attention processor layers for
         [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).