mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Support different strength for Stable Diffusion TensorRT Inpainting pipeline (#4216)
* Support different strength * run make style
This commit is contained in:
@@ -823,14 +823,14 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
|
||||
return self
|
||||
|
||||
def __initialize_timesteps(self, timesteps, strength):
|
||||
self.scheduler.set_timesteps(timesteps)
|
||||
offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
|
||||
init_timestep = int(timesteps * strength) + offset
|
||||
init_timestep = min(init_timestep, timesteps)
|
||||
t_start = max(timesteps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device)
|
||||
return timesteps, t_start
|
||||
def __initialize_timesteps(self, num_inference_steps, strength):
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
offset = self.scheduler.config.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].to(self.torch_device)
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def __preprocess_images(self, batch_size, images=()):
|
||||
init_images = []
|
||||
@@ -953,7 +953,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
strength: float = 0.75,
|
||||
strength: float = 1.0,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -1043,9 +1043,32 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
latent_height = self.image_height // 8
|
||||
latent_width = self.image_width // 8
|
||||
|
||||
# Pre-process input images
|
||||
mask, masked_image, init_image = self.__preprocess_images(
|
||||
batch_size,
|
||||
prepare_mask_and_masked_image(
|
||||
image,
|
||||
mask_image,
|
||||
self.image_height,
|
||||
self.image_width,
|
||||
return_image=True,
|
||||
),
|
||||
)
|
||||
# print(mask)
|
||||
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))
|
||||
mask = torch.cat([mask] * 2)
|
||||
|
||||
# Initialize timesteps
|
||||
timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
|
||||
|
||||
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size)
|
||||
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
||||
is_strength_max = strength == 1.0
|
||||
|
||||
# Pre-initialize latents
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
latents = self.prepare_latents(
|
||||
latents_outputs = self.prepare_latents(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
self.image_height,
|
||||
@@ -1053,16 +1076,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
torch.float32,
|
||||
self.torch_device,
|
||||
generator,
|
||||
image=init_image,
|
||||
timestep=latent_timestep,
|
||||
is_strength_max=is_strength_max,
|
||||
)
|
||||
|
||||
# Pre-process input images
|
||||
mask, masked_image = self.__preprocess_images(batch_size, prepare_mask_and_masked_image(image, mask_image))
|
||||
# print(mask)
|
||||
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))
|
||||
mask = torch.cat([mask] * 2)
|
||||
|
||||
# Initialize timesteps
|
||||
timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
|
||||
latents = latents_outputs[0]
|
||||
|
||||
# VAE encode masked image
|
||||
masked_latents = self.__encode_image(masked_image)
|
||||
|
||||
Reference in New Issue
Block a user