mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
DN6
2025-04-09 15:34:42 +05:30
parent 644147a198
commit b365801c57
8 changed files with 41 additions and 34 deletions

View File

@@ -35,8 +35,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import (DistributedDataParallelKwargs,
-                              ProjectConfiguration, set_seed)
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, hf_hub_download, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
@@ -52,24 +51,31 @@ from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 import diffusers
-from diffusers import (AutoencoderKL, DDPMScheduler,
-                       DPMSolverMultistepScheduler, EDMEulerScheduler,
-                       EulerDiscreteScheduler, StableDiffusionXLPipeline,
-                       UNet2DConditionModel)
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DPMSolverMultistepScheduler,
+    EDMEulerScheduler,
+    EulerDiscreteScheduler,
+    StableDiffusionXLPipeline,
+    UNet2DConditionModel,
+)
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import (_set_state_dict_into_text_encoder,
-                                      cast_training_params, compute_snr)
-from diffusers.utils import (check_min_version, convert_all_state_dict_to_peft,
-                             convert_state_dict_to_diffusers,
-                             convert_state_dict_to_kohya,
-                             convert_unet_state_dict_to_peft,
-                             is_wandb_available)
-from diffusers.utils.hub_utils import (load_or_create_model_card,
-                                       populate_model_card)
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_all_state_dict_to_peft,
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_kohya,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 if is_wandb_available():
     import wandb

View File

@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
         return_tensors="pt",
     )
     tokens = batch_encoding["input_ids"]
-    assert (
-        torch.count_nonzero(tokens - 49407) == 2
-    ), f"String '{string}' maps to more than a single token. Please use another string"
+    assert torch.count_nonzero(tokens - 49407) == 2, (
+        f"String '{string}' maps to more than a single token. Please use another string"
+    )
     return tokens[0, 1]

View File

@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
-        assert (
-            H == self.img_size[0] and W == self.img_size[1]
-        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        assert H == self.img_size[0] and W == self.img_size[1], (
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        )
         x = self.proj(x).flatten(2).permute(0, 2, 1)
         return x

View File

@@ -26,6 +26,7 @@ _import_structure = {}
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
@@ -41,7 +42,6 @@ if is_torch_available():
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
-    _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["cache_utils"] = ["CacheMixin"]
     _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
     _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]

View File

@@ -19,8 +19,7 @@ import ftfy
 import PIL
 import regex as re
 import torch
-from transformers import (AutoTokenizer, CLIPImageProcessor, CLIPVisionModel,
-                          UMT5EncoderModel)
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
@@ -33,6 +32,7 @@ from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import WanPipelineOutput
+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -334,7 +334,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
             )
         if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
-            raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
+            raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

View File

@@ -29,6 +29,7 @@ from packaging.version import Version, parse
 from . import logging
+
 # The package importlib_metadata is in a different place, depending on the python version.
 if sys.version_info < (3, 8):
     import importlib_metadata

View File

@@ -2639,12 +2639,12 @@ class FasterCacheTesterMixin:
         output = run_forward(pipe).flatten()
         image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
-        assert np.allclose(
-            original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
-        ), "FasterCache outputs should not differ much in specified timestep range."
-        assert np.allclose(
-            original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
-        ), "Outputs from normal inference and after disabling cache should not differ."
+        assert np.allclose(original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol), (
+            "FasterCache outputs should not differ much in specified timestep range."
+        )
+        assert np.allclose(original_image_slice, image_slice_faster_cache_disabled, atol=1e-4), (
+            "Outputs from normal inference and after disabling cache should not differ."
+        )
     def test_faster_cache_state(self):
         from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK

View File

@@ -47,9 +47,9 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
         for param_name, param_value in model_single_file.config.items():
             if param_name in PARAMS_TO_IGNORE:
                 continue
-            assert (
-                model.config[param_name] == param_value
-            ), f"{param_name} differs between single file loading and pretrained loading"
+            assert model.config[param_name] == param_value, (
+                f"{param_name} differs between single file loading and pretrained loading"
+            )
     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths: