1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Wuerstchen] fix compel usage (#4999)

* fix compel usage

* minor changes in documentation

* fix tests

* fix more

* fix more

* typos

* fix tests

* formatting

---------

Co-authored-by: Dominic Rampas <d6582533@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Kashif Rasul
2023-09-13 14:54:59 +02:00
committed by GitHub
parent 0ea51627f1
commit 77373c5eb1
4 changed files with 109 additions and 139 deletions

View File

@@ -8,9 +8,12 @@ The abstract from the paper is:
*We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.*
## Würstchen Overview
Würstchen is a diffusion model, whose text-conditional model works in a highly compressed latent space of images. Why is this important? Compressing data can reduce computational costs for both training and inference by magnitudes. Training on 1024x1024 images is way more expensive than training on 32x32. Usually, other works make use of a relatively small compression, in the range of 4x - 8x spatial compression. Würstchen takes this to an extreme. Through its novel design, we achieve a 42x spatial compression. This was unseen before because common methods fail to faithfully reconstruct detailed images after 16x spatial compression. Würstchen employs a two-stage compression, what we call Stage A and Stage B. Stage A is a VQGAN, and Stage B is a Diffusion Autoencoder (more details can be found in the [paper](https://huggingface.co/papers/2306.00637)). A third model, Stage C, is learned in that highly compressed latent space. This training requires fractions of the compute used for current top-performing models, while also allowing cheaper and faster inference.
## Würstchen v2 comes to Diffusers
After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competetive to current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements.
After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competitive to current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements.
- Higher resolution (1024x1024 up to 2048x2048)
- Faster inference
@@ -22,16 +25,16 @@ We are releasing 3 checkpoints for the text-conditional image generation model (
- v2-base
- v2-aesthetic
- v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
- **(default)** v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
We recommend to use v2-interpolated, as it has a nice touch of both photorealism and aesthetic. Use v2-base for finetunings as it does not have a style bias and use v2-aesthetic for very artistic generations.
We recommend using v2-interpolated, as it has a nice touch of both photorealism and aesthetics. Use v2-base for finetunings as it does not have a style bias and use v2-aesthetic for very artistic generations.
A comparison can be seen here:
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/2914830f-cbd3-461c-be64-d50734f4b49d" width=500>
## Text-to-Image Generation
For the sake of usability Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenCombinedPipeline` and can be used as follows:
For the sake of usability, Würstchen can be used with a single pipeline. This pipeline can be used as follows:
```python
import torch
@@ -85,7 +88,6 @@ decoder_output = decoder_pipeline(
image_embeddings=prior_output.image_embeddings,
prompt=caption,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
guidance_scale=0.0,
output_type="pil",
).images
@@ -95,8 +97,8 @@ decoder_output = decoder_pipeline(
You can make use of `torch.compile` function and gain a speed-up of about 2-3x:
```python
pipeline.prior = torch.compile(pipeline.prior, mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)
```
## Limitations

View File

@@ -19,7 +19,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, replace_example_docstring
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel
@@ -72,6 +72,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
width=int(24*10.67)=256 in order to match the training conditions.
"""
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
def __init__(
self,
tokenizer: CLIPTokenizer,
@@ -103,35 +105,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents = latents * scheduler.init_noise_sigma
return latents
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder, self.decoder]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.prior_hook = hook
_, hook = cpu_offload_with_hook(self.vqgan, device, prev_module_hook=self.prior_hook)
self.final_offload_hook = hook
def encode_prompt(
self,
prompt,
@@ -214,48 +187,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
# to avoid doing two forward passes
return text_encoder_hidden_states, uncond_text_encoder_hidden_states
def check_inputs(
self,
image_embeddings,
prompt,
negative_prompt,
num_inference_steps,
do_classifier_free_guidance,
device,
dtype,
):
if not isinstance(prompt, list):
if isinstance(prompt, str):
prompt = [prompt]
else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
if do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
else:
raise TypeError(
f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
)
if isinstance(image_embeddings, list):
image_embeddings = torch.cat(image_embeddings, dim=0)
if isinstance(image_embeddings, np.ndarray):
image_embeddings = torch.Tensor(image_embeddings, device=device).to(dtype=dtype)
if not isinstance(image_embeddings, torch.Tensor):
raise TypeError(
f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}."
)
if not isinstance(num_inference_steps, int):
raise TypeError(
f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\
In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
)
return image_embeddings, prompt, negative_prompt, num_inference_steps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -324,9 +255,35 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 1. Check inputs. Raise error if not correct
image_embeddings, prompt, negative_prompt, num_inference_steps = self.check_inputs(
image_embeddings, prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, device, dtype
)
if not isinstance(prompt, list):
if isinstance(prompt, str):
prompt = [prompt]
else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
if do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
else:
raise TypeError(
f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
)
if isinstance(image_embeddings, list):
image_embeddings = torch.cat(image_embeddings, dim=0)
if isinstance(image_embeddings, np.ndarray):
image_embeddings = torch.Tensor(image_embeddings, device=device).to(dtype=dtype)
if not isinstance(image_embeddings, torch.Tensor):
raise TypeError(
f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}."
)
if not isinstance(num_inference_steps, int):
raise TypeError(
f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\
In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
)
# 2. Encode caption
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
@@ -390,6 +347,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1)
# Offload all models
self.maybe_free_model_hooks()
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}")

View File

@@ -62,7 +62,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
The prior tokenizer to be used for text inputs.
prior_text_encoder (`CLIPTextModel`):
The prior text encoder to be used for text inputs.
prior (`WuerstchenPrior`):
prior_prior (`WuerstchenPrior`):
The prior model to be used for prior pipeline.
prior_scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for prior pipeline.
@@ -119,8 +119,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
self.prior_pipe.enable_model_cpu_offload()
self.decoder_pipe.enable_model_cpu_offload()
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
@@ -144,7 +144,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 512,
prior_num_inference_steps: int = 60,
@@ -249,7 +249,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
outputs = self.decoder_pipe(
image_embeddings=image_embeddings,
prompt=prompt,
prompt=prompt if prompt is not None else "",
num_inference_steps=num_inference_steps,
timesteps=decoder_timesteps,
guidance_scale=decoder_guidance_scale,
@@ -258,4 +258,5 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
output_type=output_type,
return_dict=return_dict,
)
return outputs

View File

@@ -23,8 +23,6 @@ from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import (
BaseOutput,
is_accelerate_available,
is_accelerate_version,
logging,
replace_example_docstring,
)
@@ -86,6 +84,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
A scheduler to be used in combination with `prior` to generate image embedding.
"""
model_cpu_offload_seq = "text_encoder->prior"
def __init__(
self,
tokenizer: CLIPTokenizer,
@@ -107,35 +107,6 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple
)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.prior_hook = hook
_, hook = cpu_offload_with_hook(self.prior, device, prev_module_hook=self.prior_hook)
self.final_offload_hook = hook
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
@@ -249,22 +220,34 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
negative_prompt,
num_inference_steps,
do_classifier_free_guidance,
batch_size,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if not isinstance(prompt, list):
if isinstance(prompt, str):
prompt = [prompt]
else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
else:
raise TypeError(
f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if not isinstance(num_inference_steps, int):
raise TypeError(
@@ -272,10 +255,6 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
)
batch_size = len(prompt) if isinstance(prompt, list) else 1
return prompt, negative_prompt, num_inference_steps, batch_size
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -361,11 +340,36 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
# 0. Define commonly used variables
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
batch_size = len(prompt) if isinstance(prompt, list) else 1
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 1. Check inputs. Raise error if not correct
prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs(
prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size
if prompt is not None and not isinstance(prompt, list):
if isinstance(prompt, str):
prompt = [prompt]
else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
if do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
else:
raise TypeError(
f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
)
self.check_inputs(
prompt,
negative_prompt,
num_inference_steps,
do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 2. Encode caption
@@ -437,6 +441,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
# 10. Denormalize the latents
latents = latents * self.config.latent_mean - self.config.latent_std
# Offload all models
self.maybe_free_model_hooks()
if output_type == "np":
latents = latents.cpu().numpy()