1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Unclip] Make sure text_embeddings & image_embeddings can directly be passed to enable interpolation tasks. (#1858)

* [Unclip] Make sure latents can be reused

* allow one to directly pass embeddings

* up

* make unclip for text work

* finish allowing to pass embeddings

* correct more

* make style
This commit is contained in:
Patrick von Platen
2022-12-30 12:18:19 +01:00
committed by GitHub
parent 29b2c93c90
commit b28ab30215
4 changed files with 270 additions and 79 deletions

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import inspect
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import functional as F
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
@@ -117,31 +118,44 @@ class UnCLIPPipeline(DiffusionPipeline):
latents = latents * scheduler.init_noise_sigma
return latents
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
):
if text_model_output is None:
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)
text_encoder_output = self.text_encoder(text_input_ids.to(device))
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
@@ -150,11 +164,10 @@ class UnCLIPPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
@@ -235,7 +248,7 @@ class UnCLIPPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
prior_num_inference_steps: int = 25,
decoder_num_inference_steps: int = 25,
@@ -244,6 +257,8 @@ class UnCLIPPipeline(DiffusionPipeline):
prior_latents: Optional[torch.FloatTensor] = None,
decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
prior_guidance_scale: float = 4.0,
decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil",
@@ -254,7 +269,8 @@ class UnCLIPPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
The prompt or prompts to guide the image generation. This can only be left undefined if
`text_model_output` and `text_attention_mask` is passed.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prior_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -287,18 +303,29 @@ class UnCLIPPipeline(DiffusionPipeline):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
text_model_output (`CLIPTextModelOutput`, *optional*):
Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
can be passed for tasks like text embedding interpolations. Make sure to also pass
`text_attention_mask` in this case. `prompt` can the be left to `None`.
text_attention_mask (`torch.Tensor`, *optional*):
Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
masks are necessary when passing `text_model_output`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
if prompt is not None:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
batch_size = text_model_output[0].shape[0]
device = self._execution_device
batch_size = batch_size * num_images_per_prompt
@@ -306,7 +333,7 @@ class UnCLIPPipeline(DiffusionPipeline):
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)
# prior
@@ -315,6 +342,7 @@ class UnCLIPPipeline(DiffusionPipeline):
prior_timesteps_tensor = self.prior_scheduler.timesteps
embedding_dim = self.prior.config.embedding_dim
prior_latents = self.prepare_latents(
(batch_size, embedding_dim),
text_embeddings.dtype,
@@ -378,6 +406,7 @@ class UnCLIPPipeline(DiffusionPipeline):
num_channels_latents = self.decoder.in_channels
height = self.decoder.sample_size
width = self.decoder.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
@@ -430,6 +459,7 @@ class UnCLIPPipeline(DiffusionPipeline):
channels = self.super_res_first.in_channels // 2
height = self.super_res_first.sample_size
width = self.super_res_first.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,

View File

@@ -126,7 +126,6 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
latents = latents * scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1
@@ -139,15 +138,6 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
)
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds
@@ -199,14 +189,15 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
return text_embeddings, text_encoder_hidden_states, text_mask
def _encode_image(self, image, device, num_images_per_prompt):
def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
if image_embeddings is None:
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
@@ -258,13 +249,14 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
num_images_per_prompt: int = 1,
decoder_num_inference_steps: int = 25,
super_res_num_inference_steps: int = 7,
generator: Optional[torch.Generator] = None,
decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None,
image_embeddings: Optional[torch.Tensor] = None,
decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -277,7 +269,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor`.
`CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
decoder_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -299,18 +291,24 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
image_embeddings (`torch.Tensor`, *optional*):
Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
can be passed for tasks like image interpolations. `image` can the be left to `None`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
"""
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
if image is not None:
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
else:
batch_size = image.shape[0]
batch_size = image_embeddings.shape[0]
prompt = [""] * batch_size
@@ -324,10 +322,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
prompt, device, num_images_per_prompt, do_classifier_free_guidance
)
image_embeddings = self._encode_image(image, device, num_images_per_prompt)
image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)
# decoder
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
text_embeddings=text_embeddings,
@@ -343,14 +340,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
num_channels_latents = self.decoder.in_channels
height = self.decoder.sample_size
width = self.decoder.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
device,
generator,
decoder_latents,
self.decoder_scheduler,
)
if decoder_latents is None:
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
device,
generator,
decoder_latents,
self.decoder_scheduler,
)
for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
@@ -395,14 +394,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
channels = self.super_res_first.in_channels // 2
height = self.super_res_first.sample_size
width = self.super_res_first.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,
device,
generator,
super_res_latents,
self.super_res_scheduler,
)
if super_res_latents is None:
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,
device,
generator,
super_res_latents,
self.super_res_scheduler,
)
interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:

View File

@@ -248,6 +248,120 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_unclip_passed_text_embed(self):
device = torch.device("cpu")
class DummyScheduler:
init_noise_sigma = 1
prior = self.dummy_prior
decoder = self.dummy_decoder
text_proj = self.dummy_text_proj
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
super_res_first = self.dummy_super_res_first
super_res_last = self.dummy_super_res_last
prior_scheduler = UnCLIPScheduler(
variance_type="fixed_small_log",
prediction_type="sample",
num_train_timesteps=1000,
clip_sample_range=5.0,
)
decoder_scheduler = UnCLIPScheduler(
variance_type="learned_range",
prediction_type="epsilon",
num_train_timesteps=1000,
)
super_res_scheduler = UnCLIPScheduler(
variance_type="fixed_small_log",
prediction_type="epsilon",
num_train_timesteps=1000,
)
pipe = UnCLIPPipeline(
prior=prior,
decoder=decoder,
text_proj=text_proj,
text_encoder=text_encoder,
tokenizer=tokenizer,
super_res_first=super_res_first,
super_res_last=super_res_last,
prior_scheduler=prior_scheduler,
decoder_scheduler=decoder_scheduler,
super_res_scheduler=super_res_scheduler,
)
pipe = pipe.to(device)
generator = torch.Generator(device=device).manual_seed(0)
dtype = prior.dtype
batch_size = 1
shape = (batch_size, prior.config.embedding_dim)
prior_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
decoder_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (
batch_size,
super_res_first.in_channels // 2,
super_res_first.sample_size,
super_res_first.sample_size,
)
super_res_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
pipe.set_progress_bar_config(disable=None)
prompt = "this is a prompt example"
generator = torch.Generator(device=device).manual_seed(0)
output = pipe(
[prompt],
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
prior_latents=prior_latents,
decoder_latents=decoder_latents,
super_res_latents=super_res_latents,
output_type="np",
)
image = output.images
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
)
text_model_output = text_encoder(text_inputs.input_ids)
text_attention_mask = text_inputs.attention_mask
generator = torch.Generator(device=device).manual_seed(0)
image_from_text = pipe(
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
prior_latents=prior_latents,
decoder_latents=decoder_latents,
super_res_latents=super_res_latents,
text_model_output=text_model_output,
text_attention_mask=text_attention_mask,
output_type="np",
)[0]
# make sure passing text embeddings manually is identical
assert np.abs(image - image_from_text).max() < 1e-4
@slow
@require_torch_gpu

View File

@@ -407,6 +407,55 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_unclip_passed_image_embed(self):
device = torch.device("cpu")
seed = 0
class DummyScheduler:
init_noise_sigma = 1
pipe = self.get_pipeline(device)
generator = torch.Generator(device=device).manual_seed(0)
dtype = pipe.decoder.dtype
batch_size = 1
shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size)
decoder_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (
batch_size,
pipe.super_res_first.in_channels // 2,
pipe.super_res_first.sample_size,
pipe.super_res_first.sample_size,
)
super_res_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
pipeline_inputs = self.get_pipeline_inputs(device, seed)
img_out_1 = pipe(
**pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
).images
pipeline_inputs = self.get_pipeline_inputs(device, seed)
# Don't pass image, instead pass embedding
image = pipeline_inputs.pop("image")
image_embeddings = pipe.image_encoder(image).image_embeds
img_out_2 = pipe(
**pipeline_inputs,
decoder_latents=decoder_latents,
super_res_latents=super_res_latents,
image_embeddings=image_embeddings,
).images
# make sure passing text embeddings manually is identical
assert np.abs(img_out_1 - img_out_2).max() < 1e-4
@slow
@require_torch_gpu
@@ -426,11 +475,10 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
"/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
)
pipeline = UnCLIPImageVariationPipeline.from_pretrained(
"fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16
)
pipeline = UnCLIPImageVariationPipeline.from_pretrained("fusing/karlo-image-variations-diffusers")
pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None)
pipeline.enable_sequential_cpu_offload()
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipeline(
@@ -442,7 +490,5 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
image = output.images[0]
np.save("./karlo_v1_alpha_cat_variation_fp16.npy", image)
assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-2
assert np.abs(expected_image - image).max() < 5e-2