diff --git a/scripts/convert_unidiffuser_to_diffusers.py b/scripts/convert_unidiffuser_to_diffusers.py index 891d289d8c..4c38172754 100644 --- a/scripts/convert_unidiffuser_to_diffusers.py +++ b/scripts/convert_unidiffuser_to_diffusers.py @@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) @@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): return mapping +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + # Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint # config.num_head_channels => num_head_channels def assign_to_checkpoint( @@ -104,8 +117,9 @@ def assign_to_checkpoint( ): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new - checkpoint. + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. """ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." @@ -143,25 +157,16 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] -# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - def create_vae_diffusers_config(config_type): # Hardcoded for now if args.config_type == "test": @@ -339,7 +344,7 @@ def create_text_decoder_config_big(): return text_decoder_config -# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint +# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1): """ Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL. @@ -674,6 +679,11 @@ if __name__ == "__main__": type=int, help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.", ) + parser.add_argument( + "--safe_serialization", + action="store_true", + help="Whether to use safetensors/safe seialization when saving the pipeline.", + ) args = parser.parse_args() @@ -766,11 +776,11 @@ if __name__ == "__main__": vae=vae, text_encoder=text_encoder, image_encoder=image_encoder, - image_processor=image_processor, + clip_image_processor=image_processor, clip_tokenizer=clip_tokenizer, text_decoder=text_decoder, text_tokenizer=text_tokenizer, unet=unet, scheduler=scheduler, ) - pipeline.save_pretrained(args.pipeline_output_path) + pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization) diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py index 6e9c9f96b0..d847831efc 100644 --- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -13,9 +13,12 @@ from transformers import ( GPT2Tokenizer, ) +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging +from ...utils import deprecate, logging from ...utils.outputs import BaseOutput from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -26,30 +29,6 @@ from .modeling_uvit import UniDiffuserModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def preprocess(image): - deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" - deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - # New BaseOutput child class for joint image-text output @dataclass class ImageTextPipelineOutput(BaseOutput): @@ -111,7 +90,7 @@ class UniDiffuserPipeline(DiffusionPipeline): vae: AutoencoderKL, text_encoder: CLIPTextModel, image_encoder: CLIPVisionModelWithProjection, - image_processor: CLIPImageProcessor, + clip_image_processor: CLIPImageProcessor, clip_tokenizer: CLIPTokenizer, text_decoder: UniDiffuserTextDecoder, text_tokenizer: GPT2Tokenizer, @@ -130,7 +109,7 @@ class UniDiffuserPipeline(DiffusionPipeline): vae=vae, text_encoder=text_encoder, image_encoder=image_encoder, - image_processor=image_processor, + clip_image_processor=clip_image_processor, clip_tokenizer=clip_tokenizer, text_decoder=text_decoder, text_tokenizer=text_tokenizer, @@ -139,6 +118,7 @@ class UniDiffuserPipeline(DiffusionPipeline): ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.num_channels_latents = vae.config.latent_channels self.text_encoder_seq_len = text_encoder.config.max_position_embeddings @@ -155,43 +135,38 @@ class UniDiffuserPipeline(DiffusionPipeline): # TODO: handle safety checking? self.safety_checker = None - # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list - def enable_model_cpu_offload(self, gpu_id=0): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + self.vae.enable_slicing() - device = torch.device(f"cuda:{gpu_id}") + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() - hook = None - for cpu_offloaded_model in [ - self.text_encoder.text_model, - self.image_encoder, - self.unet, - self.vae, - self.text_decoder.encode_prefix, - self.text_decoder.decode_prefix, - self.text_decoder, - ]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -370,8 +345,7 @@ class UniDiffuserPipeline(DiffusionPipeline): ) return batch_size, multiplier - # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - # self.tokenizer => self.clip_tokenizer + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, @@ -381,6 +355,41 @@ class UniDiffuserPipeline(DiffusionPipeline): negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with self.tokenizer->self.clip_tokenizer + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -396,8 +405,8 @@ class UniDiffuserPipeline(DiffusionPipeline): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -405,7 +414,20 @@ class UniDiffuserPipeline(DiffusionPipeline): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -414,6 +436,10 @@ class UniDiffuserPipeline(DiffusionPipeline): batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.clip_tokenizer) + text_inputs = self.clip_tokenizer( prompt, padding="max_length", @@ -440,13 +466,31 @@ class UniDiffuserPipeline(DiffusionPipeline): else: attention_mask = None - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -458,7 +502,7 @@ class UniDiffuserPipeline(DiffusionPipeline): uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -474,6 +518,10 @@ class UniDiffuserPipeline(DiffusionPipeline): else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.clip_tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.clip_tokenizer( uncond_tokens, @@ -498,17 +546,12 @@ class UniDiffuserPipeline(DiffusionPipeline): # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents # Add num_prompts_per_image argument, sample from autoencoder moment distribution @@ -587,7 +630,7 @@ class UniDiffuserPipeline(DiffusionPipeline): f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - preprocessed_image = self.image_processor.preprocess( + preprocessed_image = self.clip_image_processor.preprocess( image, return_tensors="pt", ) @@ -628,17 +671,6 @@ class UniDiffuserPipeline(DiffusionPipeline): return image_latents - # Note that the CLIP latents are not decoded for image generation. - # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - # Rename: decode_latents -> decode_image_latents - def decode_image_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - def prepare_text_latents( self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None ): @@ -720,6 +752,17 @@ class UniDiffuserPipeline(DiffusionPipeline): latents = latents * self.scheduler.init_noise_sigma return latents + def decode_text_latents(self, text_latents, device): + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + generated_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + return generated_text + def _split(self, x, height, width): r""" Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W) @@ -1181,7 +1224,7 @@ class UniDiffuserPipeline(DiffusionPipeline): # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. # Note that this differs from the formulation in the unidiffusers paper! - # do_classifier_free_guidance = guidance_scale > 1.0 + do_classifier_free_guidance = guidance_scale > 1.0 # check if scheduler is in sigmas space # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") @@ -1194,15 +1237,18 @@ class UniDiffuserPipeline(DiffusionPipeline): if mode in ["text2img"]: # 3.1. Encode input prompt, if available assert prompt is not None or prompt_embeds is not None - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=multiplier, - do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now + do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + + # if do_classifier_free_guidance: + # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) else: # 3.2. Prepare text latent variables, if input not available prompt_embeds = self.prepare_text_latents( @@ -1224,7 +1270,7 @@ class UniDiffuserPipeline(DiffusionPipeline): # 4.1. Encode images, if available assert image is not None, "`img2text` requires a conditioning image" # Encode image using VAE - image_vae = preprocess(image) + image_vae = self.image_processor.preprocess(image) height, width = image_vae.shape[-2:] image_vae_latents = self.encode_image_vae_latents( image=image_vae, @@ -1324,48 +1370,42 @@ class UniDiffuserPipeline(DiffusionPipeline): callback(i, t, latents) # 9. Post-processing - gen_image = None - gen_text = None + image = None + text = None if mode == "joint": image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) - # Map latent VAE image back to pixel space - gen_image = self.decode_image_latents(image_vae_latents) + if not output_type == "latent": + # Map latent VAE image back to pixel space + image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = image_vae_latents - # Generate text using the text decoder - output_token_list, seq_lengths = self.text_decoder.generate_captions( - text_latents, self.text_tokenizer.eos_token_id, device=device - ) - output_list = output_token_list.cpu().numpy() - gen_text = [ - self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) - for output, length in zip(output_list, seq_lengths) - ] + text = self.decode_text_latents(text_latents, device) elif mode in ["text2img", "img"]: image_vae_latents, image_clip_latents = self._split(latents, height, width) - gen_image = self.decode_image_latents(image_vae_latents) + + if not output_type == "latent": + # Map latent VAE image back to pixel space + image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = image_vae_latents elif mode in ["img2text", "text"]: text_latents = latents - output_token_list, seq_lengths = self.text_decoder.generate_captions( - text_latents, self.text_tokenizer.eos_token_id, device=device - ) - output_list = output_token_list.cpu().numpy() - gen_text = [ - self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) - for output, length in zip(output_list, seq_lengths) - ] + text = self.decode_text_latents(text_latents, device) self.maybe_free_model_hooks() - # 10. Convert to PIL - if output_type == "pil" and gen_image is not None: - gen_image = self.numpy_to_pil(gen_image) + # 10. Postprocess the image, if necessary + if image is not None: + do_denormalize = [True] * image.shape[0] + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (gen_image, gen_text) + return (image, text) - return ImageTextPipelineOutput(images=gen_image, text=gen_text) + return ImageTextPipelineOutput(images=image, text=text) diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py index 01eee68c76..ba8026db61 100644 --- a/tests/pipelines/unidiffuser/test_unidiffuser.py +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -1,5 +1,6 @@ import gc import random +import traceback import unittest import numpy as np @@ -20,17 +21,70 @@ from diffusers import ( UniDiffuserPipeline, UniDiffuserTextDecoder, ) -from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + load_image, + nightly, + require_torch_2, + require_torch_gpu, + run_test_in_subprocess, + torch_device, +) from diffusers.utils.torch_utils import randn_tensor -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin -class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +enable_full_determinism() + + +# Will be run via run_test_in_subprocess +def _test_unidiffuser_compile(in_queue, out_queue, timeout): + error = None + try: + inputs = in_queue.get(timeout=timeout) + torch_device = inputs.pop("torch_device") + seed = inputs.pop("seed") + inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed) + + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1") + # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe = pipe.to(torch_device) + + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + pipe.set_progress_bar_config(disable=None) + + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520]) + assert np.abs(image_slice - expected_slice).max() < 1e-1 + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + +class UniDiffuserPipelineFastTests( + PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase +): pipeline_class = UniDiffuserPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + # vae_latents, not latents, is the argument that corresponds to VAE latent inputs + image_latents_params = frozenset(["vae_latents"]) def get_dummy_components(self): unet = UniDiffuserModel.from_pretrained( @@ -64,7 +118,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): subfolder="image_encoder", ) # From the Stable Diffusion Image Variation pipeline tests - image_processor = CLIPImageProcessor(crop_size=32, size=32) + clip_image_processor = CLIPImageProcessor(crop_size=32, size=32) # image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip") text_tokenizer = GPT2Tokenizer.from_pretrained( @@ -80,7 +134,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "vae": vae, "text_encoder": text_encoder, "image_encoder": image_encoder, - "image_processor": image_processor, + "clip_image_processor": clip_image_processor, "clip_tokenizer": clip_tokenizer, "text_decoder": text_decoder, "text_tokenizer": text_tokenizer, @@ -619,6 +673,19 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase): expected_text_prefix = "An astronaut" assert text[0][: len(expected_text_prefix)] == expected_text_prefix + @unittest.skip(reason="Skip torch.compile test to speed up the slow test suite.") + @require_torch_2 + def test_unidiffuser_compile(self, seed=0): + inputs = self.get_inputs(torch_device, seed=seed, generate_latents=True) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + # Can't pickle a Generator object + del inputs["generator"] + inputs["torch_device"] = torch_device + inputs["seed"] = seed + run_test_in_subprocess(test_case=self, target_func=_test_unidiffuser_compile, inputs=inputs) + @nightly @require_torch_gpu