mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Unclip] Make sure text_embeddings & image_embeddings can directly be passed to enable interpolation tasks. (#1858)
* [Unclip] Make sure latents can be reused * allow one to directly pass embeddings * up * make unclip for text work * finish allowing to pass embeddings * correct more * make style
This commit is contained in:
committed by
GitHub
parent
29b2c93c90
commit
b28ab30215
@@ -13,12 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
||||
|
||||
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
||||
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
|
||||
@@ -117,31 +118,44 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
text_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
|
||||
text_attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if text_model_output is None:
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
text_input_ids = text_inputs.input_ids
|
||||
text_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_embeddings = text_encoder_output.text_embeds
|
||||
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
|
||||
text_embeddings = text_encoder_output.text_embeds
|
||||
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
||||
|
||||
else:
|
||||
batch_size = text_model_output[0].shape[0]
|
||||
text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
|
||||
text_mask = text_attention_mask
|
||||
|
||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
@@ -150,11 +164,10 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens = [""] * batch_size
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -235,7 +248,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prior_num_inference_steps: int = 25,
|
||||
decoder_num_inference_steps: int = 25,
|
||||
@@ -244,6 +257,8 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
prior_latents: Optional[torch.FloatTensor] = None,
|
||||
decoder_latents: Optional[torch.FloatTensor] = None,
|
||||
super_res_latents: Optional[torch.FloatTensor] = None,
|
||||
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
|
||||
text_attention_mask: Optional[torch.Tensor] = None,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
decoder_guidance_scale: float = 8.0,
|
||||
output_type: Optional[str] = "pil",
|
||||
@@ -254,7 +269,8 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
The prompt or prompts to guide the image generation. This can only be left undefined if
|
||||
`text_model_output` and `text_attention_mask` is passed.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prior_num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
@@ -287,18 +303,29 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
text_model_output (`CLIPTextModelOutput`, *optional*):
|
||||
Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
|
||||
can be passed for tasks like text embedding interpolations. Make sure to also pass
|
||||
`text_attention_mask` in this case. `prompt` can the be left to `None`.
|
||||
text_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
|
||||
masks are necessary when passing `text_model_output`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
if prompt is not None:
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
batch_size = text_model_output[0].shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
@@ -306,7 +333,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
|
||||
|
||||
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
|
||||
)
|
||||
|
||||
# prior
|
||||
@@ -315,6 +342,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
prior_timesteps_tensor = self.prior_scheduler.timesteps
|
||||
|
||||
embedding_dim = self.prior.config.embedding_dim
|
||||
|
||||
prior_latents = self.prepare_latents(
|
||||
(batch_size, embedding_dim),
|
||||
text_embeddings.dtype,
|
||||
@@ -378,6 +406,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
text_encoder_hidden_states.dtype,
|
||||
@@ -430,6 +459,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
image_small.dtype,
|
||||
|
||||
@@ -126,7 +126,6 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
@@ -139,15 +138,6 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
text_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
|
||||
text_embeddings = text_encoder_output.text_embeds
|
||||
@@ -199,14 +189,15 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
|
||||
return text_embeddings, text_encoder_hidden_states, text_mask
|
||||
|
||||
def _encode_image(self, image, device, num_images_per_prompt):
|
||||
def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
|
||||
if image_embeddings is None:
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeddings = self.image_encoder(image).image_embeds
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeddings = self.image_encoder(image).image_embeds
|
||||
|
||||
image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
@@ -258,13 +249,14 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
||||
image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
decoder_num_inference_steps: int = 25,
|
||||
super_res_num_inference_steps: int = 7,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
decoder_latents: Optional[torch.FloatTensor] = None,
|
||||
super_res_latents: Optional[torch.FloatTensor] = None,
|
||||
image_embeddings: Optional[torch.Tensor] = None,
|
||||
decoder_guidance_scale: float = 8.0,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -277,7 +269,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
|
||||
configuration of
|
||||
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
|
||||
`CLIPFeatureExtractor`.
|
||||
`CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
decoder_num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
@@ -299,18 +291,24 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
image_embeddings (`torch.Tensor`, *optional*):
|
||||
Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
|
||||
can be passed for tasks like image interpolations. `image` can the be left to `None`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
"""
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
batch_size = 1
|
||||
elif isinstance(image, list):
|
||||
batch_size = len(image)
|
||||
if image is not None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
batch_size = 1
|
||||
elif isinstance(image, list):
|
||||
batch_size = len(image)
|
||||
else:
|
||||
batch_size = image.shape[0]
|
||||
else:
|
||||
batch_size = image.shape[0]
|
||||
batch_size = image_embeddings.shape[0]
|
||||
|
||||
prompt = [""] * batch_size
|
||||
|
||||
@@ -324,10 +322,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance
|
||||
)
|
||||
|
||||
image_embeddings = self._encode_image(image, device, num_images_per_prompt)
|
||||
image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)
|
||||
|
||||
# decoder
|
||||
|
||||
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
|
||||
image_embeddings=image_embeddings,
|
||||
text_embeddings=text_embeddings,
|
||||
@@ -343,14 +340,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
text_encoder_hidden_states.dtype,
|
||||
device,
|
||||
generator,
|
||||
decoder_latents,
|
||||
self.decoder_scheduler,
|
||||
)
|
||||
|
||||
if decoder_latents is None:
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
text_encoder_hidden_states.dtype,
|
||||
device,
|
||||
generator,
|
||||
decoder_latents,
|
||||
self.decoder_scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
@@ -395,14 +394,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
image_small.dtype,
|
||||
device,
|
||||
generator,
|
||||
super_res_latents,
|
||||
self.super_res_scheduler,
|
||||
)
|
||||
|
||||
if super_res_latents is None:
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
image_small.dtype,
|
||||
device,
|
||||
generator,
|
||||
super_res_latents,
|
||||
self.super_res_scheduler,
|
||||
)
|
||||
|
||||
interpolate_antialias = {}
|
||||
if "antialias" in inspect.signature(F.interpolate).parameters:
|
||||
|
||||
@@ -248,6 +248,120 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_unclip_passed_text_embed(self):
|
||||
device = torch.device("cpu")
|
||||
|
||||
class DummyScheduler:
|
||||
init_noise_sigma = 1
|
||||
|
||||
prior = self.dummy_prior
|
||||
decoder = self.dummy_decoder
|
||||
text_proj = self.dummy_text_proj
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = self.dummy_tokenizer
|
||||
super_res_first = self.dummy_super_res_first
|
||||
super_res_last = self.dummy_super_res_last
|
||||
|
||||
prior_scheduler = UnCLIPScheduler(
|
||||
variance_type="fixed_small_log",
|
||||
prediction_type="sample",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample_range=5.0,
|
||||
)
|
||||
|
||||
decoder_scheduler = UnCLIPScheduler(
|
||||
variance_type="learned_range",
|
||||
prediction_type="epsilon",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
super_res_scheduler = UnCLIPScheduler(
|
||||
variance_type="fixed_small_log",
|
||||
prediction_type="epsilon",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
pipe = UnCLIPPipeline(
|
||||
prior=prior,
|
||||
decoder=decoder,
|
||||
text_proj=text_proj,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
super_res_first=super_res_first,
|
||||
super_res_last=super_res_last,
|
||||
prior_scheduler=prior_scheduler,
|
||||
decoder_scheduler=decoder_scheduler,
|
||||
super_res_scheduler=super_res_scheduler,
|
||||
)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
dtype = prior.dtype
|
||||
batch_size = 1
|
||||
|
||||
shape = (batch_size, prior.config.embedding_dim)
|
||||
prior_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
|
||||
decoder_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
super_res_first.in_channels // 2,
|
||||
super_res_first.sample_size,
|
||||
super_res_first.sample_size,
|
||||
)
|
||||
super_res_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "this is a prompt example"
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
prior_num_inference_steps=2,
|
||||
decoder_num_inference_steps=2,
|
||||
super_res_num_inference_steps=2,
|
||||
prior_latents=prior_latents,
|
||||
decoder_latents=decoder_latents,
|
||||
super_res_latents=super_res_latents,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_model_output = text_encoder(text_inputs.input_ids)
|
||||
text_attention_mask = text_inputs.attention_mask
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_text = pipe(
|
||||
generator=generator,
|
||||
prior_num_inference_steps=2,
|
||||
decoder_num_inference_steps=2,
|
||||
super_res_num_inference_steps=2,
|
||||
prior_latents=prior_latents,
|
||||
decoder_latents=decoder_latents,
|
||||
super_res_latents=super_res_latents,
|
||||
text_model_output=text_model_output,
|
||||
text_attention_mask=text_attention_mask,
|
||||
output_type="np",
|
||||
)[0]
|
||||
|
||||
# make sure passing text embeddings manually is identical
|
||||
assert np.abs(image - image_from_text).max() < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -407,6 +407,55 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_unclip_passed_image_embed(self):
|
||||
device = torch.device("cpu")
|
||||
seed = 0
|
||||
|
||||
class DummyScheduler:
|
||||
init_noise_sigma = 1
|
||||
|
||||
pipe = self.get_pipeline(device)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
dtype = pipe.decoder.dtype
|
||||
batch_size = 1
|
||||
|
||||
shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size)
|
||||
decoder_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
pipe.super_res_first.in_channels // 2,
|
||||
pipe.super_res_first.sample_size,
|
||||
pipe.super_res_first.sample_size,
|
||||
)
|
||||
super_res_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
|
||||
pipeline_inputs = self.get_pipeline_inputs(device, seed)
|
||||
|
||||
img_out_1 = pipe(
|
||||
**pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
|
||||
).images
|
||||
|
||||
pipeline_inputs = self.get_pipeline_inputs(device, seed)
|
||||
# Don't pass image, instead pass embedding
|
||||
image = pipeline_inputs.pop("image")
|
||||
image_embeddings = pipe.image_encoder(image).image_embeds
|
||||
|
||||
img_out_2 = pipe(
|
||||
**pipeline_inputs,
|
||||
decoder_latents=decoder_latents,
|
||||
super_res_latents=super_res_latents,
|
||||
image_embeddings=image_embeddings,
|
||||
).images
|
||||
|
||||
# make sure passing text embeddings manually is identical
|
||||
assert np.abs(img_out_1 - img_out_2).max() < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@@ -426,11 +475,10 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
"/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
|
||||
)
|
||||
|
||||
pipeline = UnCLIPImageVariationPipeline.from_pretrained(
|
||||
"fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = UnCLIPImageVariationPipeline.from_pretrained("fusing/karlo-image-variations-diffusers")
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
pipeline.enable_sequential_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipeline(
|
||||
@@ -442,7 +490,5 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
np.save("./karlo_v1_alpha_cat_variation_fp16.npy", image)
|
||||
|
||||
assert image.shape == (256, 256, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-2
|
||||
assert np.abs(expected_image - image).max() < 5e-2
|
||||
|
||||
Reference in New Issue
Block a user