From e391b789ac97aeda895ade6e71ac637a06ddc719 Mon Sep 17 00:00:00 2001 From: JinK Date: Fri, 4 Aug 2023 04:32:44 +0900 Subject: [PATCH] Support different strength for Stable Diffusion TensorRT Inpainting pipeline (#4216) * Support different strength * run make style --- .../stable_diffusion_tensorrt_inpaint.py | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py index 44f3bf5049..d17a691110 100755 --- a/examples/community/stable_diffusion_tensorrt_inpaint.py +++ b/examples/community/stable_diffusion_tensorrt_inpaint.py @@ -823,14 +823,14 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): return self - def __initialize_timesteps(self, timesteps, strength): - self.scheduler.set_timesteps(timesteps) - offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 - init_timestep = int(timesteps * strength) + offset - init_timestep = min(init_timestep, timesteps) - t_start = max(timesteps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device) - return timesteps, t_start + def __initialize_timesteps(self, num_inference_steps, strength): + self.scheduler.set_timesteps(num_inference_steps) + offset = self.scheduler.config.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].to(self.torch_device) + return timesteps, num_inference_steps - t_start def __preprocess_images(self, batch_size, images=()): init_images = [] @@ -953,7 +953,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): prompt: Union[str, List[str]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, - strength: float = 0.75, + strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -1043,9 +1043,32 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): latent_height = self.image_height // 8 latent_width = self.image_width // 8 + # Pre-process input images + mask, masked_image, init_image = self.__preprocess_images( + batch_size, + prepare_mask_and_masked_image( + image, + mask_image, + self.image_height, + self.image_width, + return_image=True, + ), + ) + # print(mask) + mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width)) + mask = torch.cat([mask] * 2) + + # Initialize timesteps + timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength) + + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + # Pre-initialize latents num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( + latents_outputs = self.prepare_latents( batch_size, num_channels_latents, self.image_height, @@ -1053,16 +1076,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): torch.float32, self.torch_device, generator, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, ) - # Pre-process input images - mask, masked_image = self.__preprocess_images(batch_size, prepare_mask_and_masked_image(image, mask_image)) - # print(mask) - mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width)) - mask = torch.cat([mask] * 2) - - # Initialize timesteps - timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength) + latents = latents_outputs[0] # VAE encode masked image masked_latents = self.__encode_image(masked_image)