mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Stable Diffusion] Add components function (#889)
* [Stable Diffusion] Add components function * uP
This commit is contained in:
committed by
GitHub
parent
2a0c823527
commit
83f8a5ff70
@@ -32,6 +32,9 @@ Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrain
|
||||
[[autodoc]] DiffusionPipeline
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
- to
|
||||
- device
|
||||
- components
|
||||
|
||||
## ImagePipelineOutput
|
||||
By default diffusion pipelines return an object of class
|
||||
|
||||
@@ -17,6 +17,26 @@ For more details about how Stable Diffusion works and how it differs from the ba
|
||||
| [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest)
|
||||
| [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** – *Text-Guided Image Inpainting* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon
|
||||
|
||||
## Tips
|
||||
|
||||
If you want to use all possible use cases in a single `DiffusionPipeline` you can either:
|
||||
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
|
||||
- Make use of the `components` functionality to instantiate all components in the most memory-efficient way:
|
||||
|
||||
```python
|
||||
>>> from diffusers import (
|
||||
... StableDiffusionPipeline,
|
||||
... StableDiffusionImg2ImgPipeline,
|
||||
... StableDiffusionInpaintPipeline,
|
||||
... )
|
||||
|
||||
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
|
||||
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
|
||||
|
||||
>>> # now you can use img2text(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline
|
||||
```
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -561,6 +561,41 @@ class DiffusionPipeline(ConfigMixin):
|
||||
model = pipeline_class(**init_kwargs)
|
||||
return model
|
||||
|
||||
@property
|
||||
def components(self) -> Dict[str, Any]:
|
||||
r"""
|
||||
|
||||
The `self.compenents` property can be useful to run different pipelines with the same weights and
|
||||
configurations to not have to re-allocate memory.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import (
|
||||
... StableDiffusionPipeline,
|
||||
... StableDiffusionImg2ImgPipeline,
|
||||
... StableDiffusionInpaintPipeline,
|
||||
... )
|
||||
|
||||
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
|
||||
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
|
||||
```
|
||||
|
||||
Returns:
|
||||
A dictionaly containing all the modules needed to initialize the pipleline.
|
||||
"""
|
||||
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
|
||||
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
|
||||
|
||||
if set(components.keys()) != expected_modules:
|
||||
raise ValueError(
|
||||
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
|
||||
f" {expected_modules} to be defined, but {components} are defined."
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
|
||||
@@ -1391,6 +1391,59 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
|
||||
def test_components(self):
|
||||
"""Test that components property works correctly"""
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
inpaint = StableDiffusionInpaintPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=self.dummy_safety_checker,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components)
|
||||
text2img = StableDiffusionPipeline(**inpaint.components)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image_inpaint = inpaint(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
).images
|
||||
image_img2img = img2img(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
init_image=init_image,
|
||||
).images
|
||||
image_text2img = text2img(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
assert image_inpaint.shape == (1, 32, 32, 3)
|
||||
assert image_img2img.shape == (1, 32, 32, 3)
|
||||
assert image_text2img.shape == (1, 128, 128, 3)
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
|
||||
Reference in New Issue
Block a user