huggingface/diffusers
Commit by Patrick von Platen, 2023-06-10 17:04:59 +02:00
21 changed files with 41 additions and 69 deletions

View File

@@ -60,8 +60,7 @@ returns both the image embeddings corresponding to the prompt and negative/uncond
 embeddings corresponding to an empty string.

 ```py
-generator = torch.Generator(device="cuda").manual_seed(12)
-image_embeds, negative_image_embeds = pipe_prior(prompt, generator=generator).to_tuple()
+image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
 ```

 <Tip warning={true}>
@@ -78,7 +77,7 @@ of the prior by a factor of 2.
 prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
 negative_prompt = "low quality, bad quality"
-image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, generator=generator).to_tuple()
+image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, guidance_scale=1.0).to_tuple()
 ```

 </Tip>
@@ -89,7 +88,9 @@ in case you are using a customized negative prompt, that you should pass this on
 with `negative_prompt=negative_prompt`:

 ```py
-image = t2i_pipe(prompt, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0]
+image = t2i_pipe(
+    prompt, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768
+).images[0]
 image.save("cheeseburger_monster.png")
 ```
@@ -160,8 +161,7 @@ pipe.to("cuda")
 prompt = "A fantasy landscape, Cinematic lighting"
 negative_prompt = "low quality, bad quality"

-generator = torch.Generator(device="cuda").manual_seed(30)
-image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, generator=generator).to_tuple()
+image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt).to_tuple()

 out = pipe(
     prompt,
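Taken together, the updated snippets read end to end roughly as below. This is a hedged sketch: the pipeline classes (`KandinskyPriorPipeline`, `KandinskyPipeline`), the checkpoint IDs, and the dtype/device choices are assumptions not shown in this diff, which only contains the `pipe_prior(...)` and `t2i_pipe(...)` calls themselves.

```py
# Hedged sketch of the full flow after this change; checkpoint IDs, pipeline
# classes, and dtype/device choices are assumptions not present in the diff.
import torch
from diffusers import KandinskyPipeline, KandinskyPriorPipeline

pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
).to("cuda")
t2i_pipe = KandinskyPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
).to("cuda")

prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
negative_prompt = "low quality, bad quality"

# guidance_scale=1.0 effectively disables classifier-free guidance in the
# prior, which is what the updated snippets above pass explicitly.
image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, guidance_scale=1.0).to_tuple()

image = t2i_pipe(
    prompt, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768
).images[0]
image.save("cheeseburger_monster.png")
```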

View File

@@ -55,7 +55,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__)
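Every example training script in this commit bumps the same guard in lockstep with the new development version. For reference, a minimal sketch of what such a `check_min_version` guard can look like; this illustrates the idea only and is not the actual implementation in `diffusers.utils`:

```py
# Minimal sketch of a version guard like check_min_version; the real helper
# in diffusers.utils may differ in details.
from packaging import version


def check_min_version(min_version: str) -> None:
    import diffusers

    # Compare the installed version against the required minimum and fail
    # loudly, since a silent mismatch leads to confusing downstream errors.
    if version.parse(diffusers.__version__) < version.parse(min_version):
        raise ImportError(
            f"This example requires diffusers>={min_version}, "
            f"but diffusers=={diffusers.__version__} is installed."
        )
```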

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = logging.getLogger(__name__)

View File

@@ -56,7 +56,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__)

View File

@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 # Cache compiled models across invocations of this script.
 cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

View File

@@ -64,7 +64,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__)

View File

@@ -51,7 +51,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__, log_level="INFO")

View File

@@ -52,7 +52,7 @@ if is_wandb_available():
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__, log_level="INFO")

View File

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = logging.getLogger(__name__)

View File

@@ -47,7 +47,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__, log_level="INFO")

View File

@@ -77,7 +77,7 @@ else:
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ else:
 # ------------------------------------------------------------------------------

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = logging.getLogger(__name__)

View File

@@ -28,7 +28,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.17.0.dev0")
+check_min_version("0.18.0.dev0")

 logger = get_logger(__name__, log_level="INFO")

View File

@@ -227,7 +227,7 @@ install_requires = [
 setup(
     name="diffusers",
-    version="0.17.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.18.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     description="Diffusers",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
-__version__ = "0.17.0.dev0"
+__version__ = "0.18.0.dev0"

 from .configuration_utils import ConfigMixin
 from .utils import (

View File

@@ -366,7 +366,7 @@ class UNet2DConditionLoadersMixin:
         """
         weight_name = weight_name or deprecate(
             "weights_name",
-            "0.18.0",
+            "0.20.0",
             "`weights_name` is deprecated, please use `weight_name` instead.",
             take_from=kwargs,
         )
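The `weight_name = weight_name or deprecate(..., take_from=kwargs)` idiom pops the old keyword out of `kwargs`, warns, and returns its value so the new argument can fall back to it. A simplified, hypothetical stand-in for that behavior (not diffusers' actual `deprecate` implementation; the filename below is illustrative):

```py
import warnings


def deprecate(name, removal_version, message, take_from=None, standard_warn=True):
    # Hypothetical simplification: pop the renamed kwarg if the caller passed
    # it, warn, and return its value for the `new = new or deprecate(...)` idiom.
    value = take_from.pop(name, None) if take_from is not None else None
    if value is not None:
        warnings.warn(
            f"`{name}` is deprecated and will be removed in version {removal_version}. {message}",
            FutureWarning,
        )
    return value


# Usage mirroring the hunk above:
kwargs = {"weights_name": "pytorch_lora_weights.bin"}
weight_name = None
weight_name = weight_name or deprecate(
    "weights_name", "0.20.0", "`weights_name` is deprecated, please use `weight_name` instead.", take_from=kwargs
)
assert weight_name == "pytorch_lora_weights.bin" and "weights_name" not in kwargs
```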

View File

@@ -29,7 +29,7 @@ from .attention_processor import AttnProcessor as AttnProcessorRename  # noqa: F401
 deprecate(
     "cross_attention",
-    "0.18.0",
+    "0.20.0",
     "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
     standard_warn=False,
 )
@@ -40,55 +40,55 @@ AttnProcessor = AttentionProcessor
 class CrossAttention(Attention):
     def __init__(self, *args, **kwargs):
-        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
-        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
+        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
+        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
         super().__init__(*args, **kwargs)


 class CrossAttnProcessor(AttnProcessorRename):
     def __init__(self, *args, **kwargs):
-        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
-        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
+        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
+        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
         super().__init__(*args, **kwargs)


 class LoRACrossAttnProcessor(LoRAAttnProcessor):
     def __init__(self, *args, **kwargs):
-        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
-        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
+        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
+        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
         super().__init__(*args, **kwargs)


 class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
     def __init__(self, *args, **kwargs):
-        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
-        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
+        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
+        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
         super().__init__(*args, **kwargs)


 class XFormersCrossAttnProcessor(XFormersAttnProcessor):
     def __init__(self, *args, **kwargs):
-        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
-        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
+        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
+        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
         super().__init__(*args, **kwargs)


 class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
     def __init__(self, *args, **kwargs):
-        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
-        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
+        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
+        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
         super().__init__(*args, **kwargs)


 class SlicedCrossAttnProcessor(SlicedAttnProcessor):
     def __init__(self, *args, **kwargs):
-        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
-        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
+        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
+        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
         super().__init__(*args, **kwargs)


 class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
     def __init__(self, *args, **kwargs):
-        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
-        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
+        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
+        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
         super().__init__(*args, **kwargs)
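For downstream code, the practical effect of these shims is that the old import path keeps working, with a warning, until 0.20.0. A short usage sketch:

```py
# Old path: still importable via the shims above, but emits a deprecation
# warning until 0.20.0.
from diffusers.models.cross_attention import CrossAttnProcessor  # deprecated alias

# New path recommended by the deprecation message:
from diffusers.models.attention_processor import AttnProcessor
```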

View File

@@ -1099,13 +1099,6 @@ class DiffusionPipeline(ConfigMixin):
         # 8. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)

-        return_cached_folder = kwargs.pop("return_cached_folder", False)
-        if return_cached_folder:
-            message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.18.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
-            deprecate("return_cached_folder", "0.18.0", message)
-            return model, cached_folder

         return model

     @classmethod
@@ -1254,7 +1247,7 @@ class DiffusionPipeline(ConfigMixin):
         # if the whole pipeline is cached we don't have to ping the Hub
         if revision in DEPRECATED_REVISION_ARGS and version.parse(
             version.parse(__version__).base_version
-        ) >= version.parse("0.18.0"):
+        ) >= version.parse("0.20.0"):
             warn_deprecated_model_variant(
                 pretrained_model_name, use_auth_token, variant, revision, model_filenames
             )
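The deleted `return_cached_folder=True` branch is replaced by the two-step recipe its own deprecation message prescribes: download first, then load from the local folder. A sketch using `StableDiffusionPipeline` (the tiny checkpoint ID is illustrative, borrowed from the test removed later in this commit):

```py
from diffusers import StableDiffusionPipeline

# 1. Resolve the checkpoint and get the local cache folder explicitly.
cached_folder = StableDiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-torch")

# 2. Load the pipeline from that folder.
pipe = StableDiffusionPipeline.from_pretrained(cached_folder)
```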

View File

@@ -280,7 +280,7 @@ def _get_model_file(
     if (
         revision in DEPRECATED_REVISION_ARGS
         and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
-        and version.parse(version.parse(__version__).base_version) >= version.parse("0.18.0")
+        and version.parse(version.parse(__version__).base_version) >= version.parse("0.20.0")
     ):
         try:
             model_file = hf_hub_download(
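The nested `version.parse` is what lets a `.dev0` build trip this gate: `base_version` strips the dev suffix before comparing. This is standard `packaging` behavior:

```py
from packaging import version

v = version.parse("0.18.0.dev0")
print(v.base_version)                                             # "0.18.0"
print(v >= version.parse("0.18.0"))                               # False: dev builds sort before the release
print(version.parse(v.base_version) >= version.parse("0.18.0"))  # True: so the gate fires on dev builds too
```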

View File

@@ -244,27 +244,6 @@ class DownloadTests(unittest.TestCase):
             use_safetensors=True,
         )

-    def test_returned_cached_folder(self):
-        prompt = "hello"
-        pipe = StableDiffusionPipeline.from_pretrained(
-            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
-        )
-        _, local_path = StableDiffusionPipeline.from_pretrained(
-            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, return_cached_folder=True
-        )
-        pipe_2 = StableDiffusionPipeline.from_pretrained(local_path)
-
-        pipe = pipe.to(torch_device)
-        pipe_2 = pipe_2.to(torch_device)
-
-        generator = torch.manual_seed(0)
-        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
-
-        generator = torch.manual_seed(0)
-        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
-
-        assert np.max(np.abs(out - out_2)) < 1e-3
-
     def test_download_safetensors(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             # pipeline has Flax weights