Mirror of https://github.com/huggingface/diffusers.git
vq diffusion classifier free sampling (#1294)

* vq diffusion classifier free sampling
* correct
* uP

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
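For orientation before the diff: a minimal usage sketch of the classifier free sampling path this PR adds to `VQDiffusionPipeline`. It mirrors the integration test further down; the device and seed are illustrative, and `guidance_scale > 1.0` is what switches the guided branch on.

```python
# Minimal usage sketch for the classifier free sampling path added in this PR.
# Values mirror the integration test in this diff; device/seed are illustrative.
import torch

from diffusers import VQDiffusionPipeline

pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipeline = pipeline.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipeline(
    "teddy bear playing in the pool",
    guidance_scale=5.0,      # > 1.0 enables classifier free guidance
    truncation_rate=0.86,
    generator=generator,
    output_type="np",
)
image = output.images[0]
```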
@@ -39,8 +39,8 @@ import torch
 import yaml

 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
-from diffusers.models.attention import Transformer2DModel
+from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
+from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader

@@ -826,6 +826,20 @@ if __name__ == "__main__":
         transformer_model, checkpoint
     )

+    # classifier free sampling embeddings interlude
+
+    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
+    # model, so we pull them off the checkpoint before the checkpoint is deleted.
+
+    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf
+
+    if learnable_classifier_free_sampling_embeddings:
+        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
+    else:
+        learned_classifier_free_sampling_embeddings_embeddings = None
+
+    # done classifier free sampling embeddings interlude
+
     with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
         torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
         del diffusers_transformer_checkpoint
@@ -871,6 +885,31 @@ if __name__ == "__main__":

     # done scheduler

+    # learned classifier free sampling embeddings
+
+    with init_empty_weights():
+        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
+            learnable_classifier_free_sampling_embeddings,
+            hidden_size=text_encoder_model.config.hidden_size,
+            length=tokenizer_model.model_max_length,
+        )
+
+    learned_classifier_free_sampling_checkpoint = {
+        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
+    }
+
+    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
+        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
+        del learned_classifier_free_sampling_checkpoint
+        del learned_classifier_free_sampling_embeddings_embeddings
+        load_checkpoint_and_dispatch(
+            learned_classifier_free_sampling_embeddings_model,
+            learned_classifier_free_sampling_checkpoint_file.name,
+            device_map="auto",
+        )
+
+    # done learned classifier free sampling embeddings
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")

     pipe = VQDiffusionPipeline(
@@ -878,6 +917,7 @@ if __name__ == "__main__":
         transformer=transformer_model,
         tokenizer=tokenizer_model,
         text_encoder=text_encoder_model,
+        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
         scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)
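This ends the changes to the conversion script. Because `save_pretrained` writes each registered module into its own subfolder, the converted dump should round-trip the new embeddings module on reload. A hedged sketch (the local path is hypothetical):

```python
# Hedged sketch: the converted pipeline round-trips the new module through
# save_pretrained / from_pretrained. The dump path below is whatever was
# passed as --dump_path to the conversion script (hypothetical here).
from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("path/to/dump")

# Expected to be True for the ithq checkpoint, which ships transformer.empty_text_embed.
print(pipe.learned_classifier_free_sampling_embeddings.learnable)
# Parameter of shape (tokenizer.model_max_length, text_encoder hidden_size) when learnable.
print(pipe.learned_classifier_free_sampling_embeddings.embeddings.shape)
```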
@@ -1 +1,5 @@
-from .pipeline_vq_diffusion import VQDiffusionPipeline
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_transformers_available() and is_torch_available():
+    from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline
@@ -20,6 +20,8 @@ from diffusers import Transformer2DModel, VQModel
 from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer

+from ...configuration_utils import ConfigMixin, register_to_config
+from ...modeling_utils import ModelMixin
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...utils import logging
@@ -27,6 +29,28 @@ from ...utils import logging
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+    """
+    Utility class for storing learned text embeddings for classifier free sampling
+    """
+
+    @register_to_config
+    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
+        super().__init__()
+
+        self.learnable = learnable
+
+        if self.learnable:
+            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
+            assert length is not None, "learnable=True requires `length` to be set"
+
+            embeddings = torch.zeros(length, hidden_size)
+        else:
+            embeddings = None
+
+        self.embeddings = torch.nn.Parameter(embeddings)
+
+
 class VQDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using VQ Diffusion
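A hedged sketch of the two modes this new utility class supports (sizes are illustrative; they match the fast tests later in the diff):

```python
# Hedged sketch of the two modes of LearnedClassifierFreeSamplingEmbeddings.
import torch

from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings

# learnable=True allocates a (length, hidden_size) parameter that the conversion
# script later fills with the checkpoint's transformer.empty_text_embed.
learned = LearnedClassifierFreeSamplingEmbeddings(learnable=True, hidden_size=32, length=77)
assert learned.embeddings.shape == (77, 32)

# learnable=False keeps no useful weights; the pipeline then falls back to
# encoding an empty prompt ("") with CLIP for the unconditional branch.
unlearned = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
assert unlearned.embeddings.numel() == 0  # nn.Parameter(None) is an empty tensor
```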
@@ -55,6 +79,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
     text_encoder: CLIPTextModel
     tokenizer: CLIPTokenizer
     transformer: Transformer2DModel
+    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
     scheduler: VQDiffusionScheduler

     def __init__(
@@ -64,6 +89,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
         tokenizer: CLIPTokenizer,
         transformer: Transformer2DModel,
         scheduler: VQDiffusionScheduler,
+        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
     ):
         super().__init__()

@@ -73,13 +99,78 @@ class VQDiffusionPipeline(DiffusionPipeline):
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
         )

+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            if self.learned_classifier_free_sampling_embeddings.learnable:
+                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+            else:
+                uncond_tokens = [""] * batch_size
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+                # See comment for normalizing text embeddings
+                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
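As the comments in `_encode_prompt` note, the unconditional and conditional embeddings are concatenated (unconditional first) so that a single transformer forward pass serves both guidance branches. A hedged shape sketch:

```python
# Hedged shape sketch for _encode_prompt with classifier free guidance enabled.
# Sizes are illustrative: batch_size prompts, num_images_per_prompt images each.
batch_size = 2
num_images_per_prompt = 3
seq_len = 77          # tokenizer.model_max_length for CLIP
hidden_size = 512     # text encoder hidden size; illustrative

# conditional half:   (batch_size * num_images_per_prompt, seq_len, hidden_size)
# unconditional half: same shape, concatenated in front
expected_shape = (2 * batch_size * num_images_per_prompt, seq_len, hidden_size)
print(expected_shape)  # (12, 77, 512)
```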
@@ -98,6 +189,12 @@ class VQDiffusionPipeline(DiffusionPipeline):
             num_inference_steps (`int`, *optional*, defaults to 100):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
             truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
                 Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
                 most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
@@ -137,6 +234,10 @@ class VQDiffusionPipeline(DiffusionPipeline):

         batch_size = batch_size * num_images_per_prompt

+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -145,35 +246,6 @@ class VQDiffusionPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )

-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output.
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it

         latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +270,19 @@ class VQDiffusionPipeline(DiffusionPipeline):
         sample = latents

         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(
+                latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
+            ).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

             model_output = self.truncate(model_output, truncation_rate)

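Since `model_output` is `log_p_x_0` (log-probabilities over the codebook classes for each latent pixel), the guidance combination happens in log space and the `logsumexp` subtraction renormalizes each pixel's distribution. A hedged numeric sketch of just that arithmetic:

```python
# Hedged numeric sketch of the guidance step above: combine unconditional and
# conditional log-probabilities, then renormalize with logsumexp so each
# latent pixel's class distribution sums to 1 again.
import torch

guidance_scale = 5.0
num_classes, num_pixels = 4, 2

log_p_uncond = torch.log_softmax(torch.randn(1, num_classes, num_pixels), dim=1)
log_p_text = torch.log_softmax(torch.randn(1, num_classes, num_pixels), dim=1)

guided = log_p_uncond + guidance_scale * (log_p_text - log_p_uncond)
guided = guided - torch.logsumexp(guided, dim=1, keepdim=True)

print(guided.exp().sum(dim=1))  # ~1.0 per pixel after renormalization
```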
@@ -20,7 +20,8 @@ import numpy as np
 import torch

 from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
-from diffusers.utils import load_image, slow, torch_device
+from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
+from diffusers.utils import load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -45,6 +46,10 @@ class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def num_embeds_ada_norm(self):
         return 12

+    @property
+    def text_embedder_hidden_size(self):
+        return 32
+
     @property
     def dummy_vqvae(self):
         torch.manual_seed(0)
@@ -71,7 +76,7 @@ class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
-            hidden_size=32,
+            hidden_size=self.text_embedder_hidden_size,
             intermediate_size=37,
             layer_norm_eps=1e-05,
             num_attention_heads=4,
@@ -111,9 +116,15 @@ class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         tokenizer = self.dummy_tokenizer
         transformer = self.dummy_transformer
         scheduler = VQDiffusionScheduler(self.num_embed)
+        learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(learnable=False)

         pipe = VQDiffusionPipeline(
-            vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler
+            vqvae=vqvae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
         )
         pipe = pipe.to(device)
         pipe.set_progress_bar_config(disable=None)
@@ -139,6 +150,50 @@ class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_vq_diffusion_classifier_free_sampling(self):
+        device = "cpu"
+
+        vqvae = self.dummy_vqvae
+        text_encoder = self.dummy_text_encoder
+        tokenizer = self.dummy_tokenizer
+        transformer = self.dummy_transformer
+        scheduler = VQDiffusionScheduler(self.num_embed)
+        learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(
+            learnable=True, hidden_size=self.text_embedder_hidden_size, length=tokenizer.model_max_length
+        )
+
+        pipe = VQDiffusionPipeline(
+            vqvae=vqvae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
+        )
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "teddy bear playing in the pool"
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = pipe(
+            [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
+        )[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 24, 24, 3)
+
+        expected_slice = np.array([0.6647, 0.6531, 0.5303, 0.5891, 0.5726, 0.4439, 0.6304, 0.5564, 0.4912])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+

 @slow
 @require_torch_gpu
@@ -149,12 +204,11 @@ class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()

-    def test_vq_diffusion(self):
-        expected_image = load_image(
+    def test_vq_diffusion_classifier_free_sampling(self):
+        expected_image = load_numpy(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/vq_diffusion/teddy_bear_pool.png"
+            "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.npy"
         )
-        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

         pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
         pipeline = pipeline.to(torch_device)
@@ -163,7 +217,6 @@ class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
         generator = torch.Generator(device=torch_device).manual_seed(0)
         output = pipeline(
             "teddy bear playing in the pool",
             truncation_rate=0.86,
-            num_images_per_prompt=1,
             generator=generator,
             output_type="np",