mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Pix2Pix] Add utility function (#2385)
* [Pix2Pix] Add utility function * improve * update * Apply suggestions from code review * uP * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
This commit is contained in:
committed by
GitHub
parent
abd5dcbbf1
commit
2e0d489a4e
@@ -542,6 +542,26 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
"""Constructs the edit direction to steer the image generation process semantically."""
|
||||
return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
|
||||
|
||||
@torch.no_grad()
def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor:
    """Encode prompts in mini-batches and return their averaged embedding.

    The prompts are tokenized and passed through ``self.text_encoder`` in
    slices of at most ``batch_size`` entries, so that a long prompt list
    does not have to go through the encoder in a single forward pass.

    Args:
        prompt: Prompts to encode. Must contain at least one entry.
        batch_size: Maximum number of prompts tokenized and encoded per
            call to ``self.text_encoder``.

    Returns:
        The mean over all prompts (dimension 0) of the first output of
        ``self.text_encoder``, with a leading dimension of size 1 added.

    Raises:
        ValueError: If ``prompt`` is empty.
    """
    num_prompts = len(prompt)
    if num_prompts == 0:
        # Without this guard the failure would surface later as an opaque
        # error from ``torch.cat`` on an empty list.
        raise ValueError("`prompt` must contain at least one entry to compute embeddings.")

    embeds = []
    for start in range(0, num_prompts, batch_size):
        prompt_slice = prompt[start : start + batch_size]

        # Pad/truncate every prompt to the tokenizer's maximum length so that
        # the per-slice outputs can be concatenated along dimension 0 below.
        input_ids = self.tokenizer(
            prompt_slice,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids

        input_ids = input_ids.to(self.text_encoder.device)
        # Only the first element of the encoder output is kept
        # (presumably the per-token hidden states -- confirm against the
        # text encoder in use).
        embeds.append(self.text_encoder(input_ids)[0])

    # Stack all slices, average across prompts, then restore a leading
    # dimension of size 1.
    return torch.cat(embeds, dim=0).mean(0)[None]
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
|
||||
Reference in New Issue
Block a user