mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
committed by
GitHub
parent
2e7a28652a
commit
857c04cfba
@@ -542,6 +542,20 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
"""Constructs the edit direction to steer the image generation process semantically."""
|
||||
return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_embeds(self, prompt: List[str]) -> torch.FloatTensor:
|
||||
input_ids = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
embeds = self.text_encoder(input_ids)[0]
|
||||
|
||||
return embeds.mean(0)[None]
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
|
||||
Reference in New Issue
Block a user