mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
|
||||
- [`StableDiffusionXLInstructPix2PixPipeline`]
|
||||
- [`StableDiffusionXLControlNetPipeline`]
|
||||
- [`StableDiffusionXLKDiffusionPipeline`]
|
||||
- [`StableDiffusion3Pipeline`]
|
||||
- [`LatentConsistencyModelPipeline`]
|
||||
- [`LatentConsistencyModelImg2ImgPipeline`]
|
||||
- [`StableDiffusionControlNetXSPipeline`]
|
||||
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
|
||||
- [`StableCascadeUNet`]
|
||||
- [`AutoencoderKL`]
|
||||
- [`ControlNetModel`]
|
||||
- [`SD3Transformer2DModel`]
|
||||
|
||||
## FromSingleFileMixin
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ The abstract from the paper is:
|
||||
|
||||
## Usage Example
|
||||
|
||||
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
|
||||
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
|
||||
|
||||
Use the command below to log in:
|
||||
Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
|
||||
|
||||
## Loading the single checkpoint for the `StableDiffusion3Pipeline`
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
from transformers import T5EncoderModel
|
||||
### Loading the single file checkpoint without T5
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
pipe = StableDiffusion3Pipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
|
||||
torch_dtype=torch.float16,
|
||||
text_encoder_3=None
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
|
||||
image.save('sd3-single-file.png')
|
||||
```
|
||||
|
||||
<Tip>
|
||||
`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
|
||||
</Tip>
|
||||
### Loading the single file checkpoint without T5
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
pipe = StableDiffusion3Pipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
|
||||
image.save('sd3-single-file-t5-fp8.png')
|
||||
```
|
||||
|
||||
## StableDiffusion3Pipeline
|
||||
|
||||
|
||||
@@ -28,9 +28,11 @@ from .single_file_utils import (
|
||||
_legacy_load_safety_checker,
|
||||
_legacy_load_scheduler,
|
||||
create_diffusers_clip_model_from_ldm,
|
||||
create_diffusers_t5_model_from_checkpoint,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
is_clip_model_in_single_file,
|
||||
is_t5_in_single_file,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
||||
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
elif is_tokenizer and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
|
||||
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = [
|
||||
"cond_stage_model.transformer.",
|
||||
"conditioner.embedders.0.transformer.",
|
||||
"text_encoders.clip_l.transformer.",
|
||||
]
|
||||
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
|
||||
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
||||
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
|
||||
|
||||
|
||||
def is_open_clip_sd3_model(checkpoint):
|
||||
is_open_clip_sdxl_refiner_model(checkpoint)
|
||||
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_open_clip_sdxl_refiner_model(checkpoint):
|
||||
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
|
||||
keys = list(checkpoint.keys())
|
||||
text_model_dict = {}
|
||||
|
||||
remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
|
||||
remove_prefixes = []
|
||||
remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
|
||||
if remove_prefix:
|
||||
remove_prefixes.append(remove_prefix)
|
||||
|
||||
for key in keys:
|
||||
for prefix in remove_prefixes:
|
||||
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
|
||||
):
|
||||
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
|
||||
|
||||
elif (
|
||||
is_clip_sd3_model(checkpoint)
|
||||
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
|
||||
):
|
||||
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
|
||||
diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
|
||||
|
||||
elif is_open_clip_model(checkpoint):
|
||||
prefix = "cond_stage_model.model."
|
||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
|
||||
prefix = "conditioner.embedders.0.model."
|
||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||
|
||||
elif is_open_clip_sd3_model(checkpoint):
|
||||
prefix = "text_encoders.clip_g.transformer."
|
||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||
elif (
|
||||
is_open_clip_sd3_model(checkpoint)
|
||||
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
|
||||
):
|
||||
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
|
||||
|
||||
else:
|
||||
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
|
||||
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
text_model_dict = {}
|
||||
|
||||
remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
|
||||
remove_prefixes = ["text_encoders.t5xxl.transformer."]
|
||||
|
||||
for key in keys:
|
||||
for prefix in remove_prefixes:
|
||||
|
||||
Reference in New Issue
Block a user