1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Pix2Pix] Add utility function (#2385)

* [Pix2Pix] Add utility function

* improve

* update

* Apply suggestions from code review

* uP

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
This commit is contained in:
Patrick von Platen
2023-02-17 11:49:00 +02:00
committed by GitHub
parent abd5dcbbf1
commit 2e0d489a4e

View File

@@ -542,6 +542,26 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
"""Constructs the edit direction to steer the image generation process semantically."""
return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
@torch.no_grad()
def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor:
num_prompts = len(prompt)
embeds = []
for i in range(0, num_prompts, batch_size):
prompt_slice = prompt[i : i + batch_size]
input_ids = self.tokenizer(
prompt_slice,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids
input_ids = input_ids.to(self.text_encoder.device)
embeds.append(self.text_encoder(input_ids)[0])
return torch.cat(embeds, dim=0).mean(0)[None]
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(