From 2e0d489a4ea2c552b0f0910fa53cc317d8e88b4d Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 17 Feb 2023 11:49:00 +0200
Subject: [PATCH] [Pix2Pix] Add utility function (#2385)

* [Pix2Pix] Add utility function

* improve

* update

* Apply suggestions from code review

* uP

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
---
 .../pipeline_stable_diffusion_pix2pix_zero.py | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index af0a7e2a4b..338f69446e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -542,6 +542,26 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
         """Constructs the edit direction to steer the image generation process semantically."""
         return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
 
+    @torch.no_grad()
+    def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor:
+        num_prompts = len(prompt)
+        embeds = []
+        for i in range(0, num_prompts, batch_size):
+            prompt_slice = prompt[i : i + batch_size]
+
+            input_ids = self.tokenizer(
+                prompt_slice,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            ).input_ids
+
+            input_ids = input_ids.to(self.text_encoder.device)
+            embeds.append(self.text_encoder(input_ids)[0])
+
+        return torch.cat(embeds, dim=0).mean(0)[None]
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
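
Usage note (not part of the patch). The new get_embeds helper tokenizes a list of captions in slices of batch_size, encodes each slice with the pipeline's text encoder, and returns the mean embedding with a leading batch dimension. Below is a minimal sketch of how it could be combined with construct_direction; the checkpoint name and caption lists are illustrative, and the construct_direction argument order is assumed from the embs_source/embs_target names visible in the diff context.

    import torch
    from diffusers import StableDiffusionPix2PixZeroPipeline

    # Illustrative checkpoint; any Stable Diffusion checkpoint compatible
    # with this pipeline should work.
    pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
    ).to("cuda")

    # Hypothetical caption lists describing the source and target concepts.
    source_prompts = ["a photo of a cat", "a painting of a cat", "a cat on a sofa"]
    target_prompts = ["a photo of a dog", "a painting of a dog", "a dog on a sofa"]

    # Each call tokenizes the captions in slices of `batch_size`, encodes
    # them, and returns the mean text embedding with a leading batch
    # dimension (shape (1, seq_len, hidden_dim)).
    source_embeds = pipe.get_embeds(source_prompts, batch_size=2)
    target_embeds = pipe.get_embeds(target_prompts, batch_size=2)

    # The edit direction is the difference of the mean target and source
    # embeddings (argument order assumed from the embs_source/embs_target
    # names in the diff context).
    edit_direction = pipe.construct_direction(source_embeds, target_embeds)

Slicing the prompt list bounds peak memory when averaging over many captions, which is why the helper encodes at most batch_size prompts per text-encoder forward pass.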