
[LoRA] quality of life improvements in the loading semantics and docs (#3180)

* 👽 qol improvements for LoRA.

* better function name?

* fix: LoRA weight loading with the new format.

* address Patrick's comments.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* change wording around encouraging the use of load_lora_weights().

* fix: function name.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sayak Paul
2023-04-28 11:36:49 +05:30
committed by Daniel Gu
parent a80f6966cf
commit 72a84677cb
7 changed files with 123 additions and 27 deletions

View File

@@ -171,7 +171,7 @@
- local: api/pipelines/semantic_stable_diffusion
title: Semantic Guidance
- local: api/pipelines/spectrogram_diffusion
title: "Spectrogram Diffusion"
title: Spectrogram Diffusion
- sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
@@ -238,6 +238,8 @@
title: DPM Discrete Scheduler
- local: api/schedulers/dpm_discrete_ancestral
title: DPM Discrete Scheduler with ancestral sampling
- local: api/schedulers/dpm_sde
title: DPMSolverSDEScheduler
- local: api/schedulers/euler_ancestral
title: Euler Ancestral Scheduler
- local: api/schedulers/euler
@@ -266,8 +268,6 @@
title: VP-SDE
- local: api/schedulers/vq_diffusion
title: VQDiffusionScheduler
- local: api/schedulers/dpm_sde
title: DPMSolverSDEScheduler
title: Schedulers
- sections:
- local: api/experimental/rl

View File

@@ -115,7 +115,7 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
</Tip>
```py
>>> pipe.unet.load_attn_procs(model_path)
>>> pipe.unet.load_attn_procs(lora_model_path)
>>> pipe.to("cuda")
# use half the weights from the LoRA finetuned model and half the weights from the base model
@@ -128,6 +128,25 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
>>> image.save("blue_pokemon.png")
```
<Tip>
If you are loading the LoRA parameters from the Hub and the Hub repository has
a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
you can do the following:
```py
from huggingface_hub.repocard import RepoCard
lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
...
```
</Tip>
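Putting the tip together with the loading steps above, a minimal end-to-end sketch could look like the following. The prompt and the `cross_attention_kwargs={"scale": 0.5}` blending value are illustrative, and `load_lora_weights()` is used here as the more general loader this guide encourages:
```py
import torch
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
base_model_id = RepoCard.load(lora_model_id).data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe.load_lora_weights(lora_model_id)
pipe.to("cuda")

# `scale` controls how much of the LoRA weights is merged in at inference:
# 0.0 uses only the base model weights, 1.0 uses only the LoRA-finetuned weights.
image = pipe(
    "A pokemon with blue eyes.",  # illustrative prompt
    num_inference_steps=25,
    guidance_scale=7.5,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
image.save("blue_pokemon.png")
```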
## DreamBooth
[DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).
@@ -208,7 +227,7 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
</Tip>
```py
>>> pipe.unet.load_attn_procs(model_path)
>>> pipe.unet.load_attn_procs(lora_model_path)
>>> pipe.to("cuda")
# use half the weights from the LoRA finetuned model and half the weights from the base model
@@ -222,4 +241,15 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
>>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
>>> image.save("bucket-dog.png")
```
Note that the use of [`LoraLoaderMixin.load_lora_weights`] is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because
[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
```py
pipe.load_lora_weights(lora_model_path)
```
* LoRA parameters that have separate identifiers for the UNet and the text encoder, such as [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth). The same call handles this case as well, as sketched below.
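For the second case, a minimal sketch could look like the following; the base checkpoint name is an assumption for illustration (in practice, resolve it from the repository's `base_model` tag as shown earlier in this guide):
```py
import torch
from diffusers import StableDiffusionPipeline

# The base checkpoint below is assumed for illustration only; read it from the
# LoRA repository's `base_model` tag when available.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# A single call covers both key layouts: UNet-only LoRA parameters as well as
# UNet + text encoder LoRA parameters stored under their respective prefixes.
pipe.load_lora_weights("sayakpaul/dreambooth")
pipe.to("cuda")

image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
image.save("bucket-dog.png")
```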

View File

@@ -355,7 +355,7 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we
Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested to know more about how we
enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
@@ -387,6 +387,33 @@ Finally, we can run the model in inference.
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
If you are loading the LoRA parameters from the Hub and the Hub repository has
a `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)), then
you can do the following:
```py
from huggingface_hub.repocard import RepoCard
lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
...
```
**Note** that we will gradually be deprecating the use of [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs) since we now have a more general
method to load the LoRA parameters -- [`LoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights). This is because
[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
```py
pipe.load_lora_weights(lora_model_path)
```
* LoRA parameters that have separate identifiers for the UNet and the text encoder, such as [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth). The same call handles this case as well, as sketched below.
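To make the distinction concrete, here is an illustrative comparison of the two loading paths, reusing the checkpoint referenced above and resolving the base model from its repo card as in the earlier snippet:
```py
import torch
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline

lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
base_model_id = RepoCard.load(lora_model_id).data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe.to("cuda")

# Older path: loads only the UNet LoRA layers and will gradually be deprecated.
# pipe.unet.load_attn_procs(lora_model_id)

# Preferred path: loads the UNet LoRA layers and, if present, the text encoder LoRA layers too.
pipe.load_lora_weights(lora_model_id)

image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```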
## Training with Flax/JAX
For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

View File

@@ -1045,7 +1045,7 @@ def main(args):
pipeline = pipeline.to(accelerator.device)
# load attention processors
pipeline.load_attn_procs(args.output_dir)
pipeline.load_lora_weights(args.output_dir)
# run inference
if args.validation_prompt and args.num_validation_images > 0:

View File

@@ -281,10 +281,14 @@ class ExamplesTestsAccelerate(unittest.TestCase):
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
# the names of the keys of the state dict should either start with `unet`
# or `text_encoder`.
# check that `text_encoder` keys are present at all.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
keys = lora_state_dict.keys()
is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
self.assertTrue(is_text_encoder_present)
# the names of the keys of the state dict should either start with `unet`
# or `text_encoder`.
is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
self.assertTrue(is_correct_naming)

View File

@@ -229,6 +229,21 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```
If you are loading the LoRA parameters from the Hub and the Hub repository has
a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
you can do the following:
```py
from huggingface_hub.repocard import RepoCard
lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
...
```
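Continuing from the snippet above, a short sketch of loading and using the LoRA parameters; the prompt is illustrative, and a local training output directory containing `pytorch_lora_weights.bin` works just as well as a Hub repository id:
```py
import torch
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
base_model_id = RepoCard.load(lora_model_id).data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
# Either a Hub repository id or a local directory with `pytorch_lora_weights.bin` works here.
pipe.load_lora_weights(lora_model_id)
pipe.to("cuda")

prompt = "A pokemon with blue eyes."  # illustrative prompt
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```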
## Training with Flax/JAX
For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,8 @@ if is_transformers_available():
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
@@ -87,6 +90,9 @@ class AttnProcsLayers(torch.nn.Module):
class UNet2DConditionLoadersMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
@@ -225,6 +231,18 @@ class UNet2DConditionLoadersMixin:
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
if is_lora:
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
)
if is_new_lora_format:
# Strip the `"unet"` prefix.
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
if is_text_encoder_present:
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
warnings.warn(warn_message)
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
lora_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
@@ -672,8 +690,8 @@ class LoraLoaderMixin:
</Tip>
"""
text_encoder_name = "text_encoder"
unet_name = "unet"
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
@@ -810,21 +828,24 @@ class LoraLoaderMixin:
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
# Load the layers corresponding to UNet.
if all(key.startswith(self.unet_name) for key in keys):
if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
unet_keys = [k for k in keys if k.startswith(self.unet_name)]
logger.info(f"Loading {self.unet_name}.")
unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)}
unet_lora_state_dict = {
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
self.unet.load_attn_procs(unet_lora_state_dict)
# Load the layers corresponding to text encoder and make necessary adjustments.
elif all(key.startswith(self.text_encoder_name) for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
logger.info(f"Loading {self.text_encoder_name}.")
text_encoder_lora_state_dict = {
k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name)
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict)
self._modify_text_encoder(attn_procs_text_encoder)
if len(text_encoder_lora_state_dict) > 0:
attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
self._modify_text_encoder(attn_procs_text_encoder)
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
@@ -832,11 +853,8 @@ class LoraLoaderMixin:
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
):
self.unet.load_attn_procs(state_dict)
deprecation_message = "You have saved the LoRA weights using the old format. This will be"
" deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
" in a dictionary and then create a new dictionary like the following:"
" `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False)
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)
def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
r"""
@@ -872,7 +890,9 @@ class LoraLoaderMixin:
else:
return "to_out_lora"
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def _load_text_encoder_attn_procs(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
):
r"""
Load pretrained attention processor layers for
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).