separated prompt encoding
@@ -359,124 +359,64 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        do_classifier_free_guidance: bool = True,
         num_videos_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
         max_sequence_length: int = 512,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
         r"""
-        Encodes the prompt into text encoder hidden states.
+        Encodes a single prompt (positive or negative) into text encoder hidden states.

         This method combines embeddings from both Qwen2.5-VL and CLIP text encoders
         to create comprehensive text representations for video generation.

         Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
-                `guidance_scale` is less than `1`).
-            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
-                Whether to use classifier free guidance or not.
+            prompt (`str` or `List[str]`):
+                Prompt to be encoded.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
-                Number of videos that should be generated per prompt.
+                Number of videos to generate per prompt.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, text embeddings will be generated from the `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt`
-                input argument.
             max_sequence_length (`int`, *optional*, defaults to 512):
                 Maximum sequence length for text encoding.
-            device: (`torch.device`, *optional*):
-                torch device
-            dtype: (`torch.dtype`, *optional*):
-                torch dtype
+            device (`torch.device`, *optional*):
+                Torch device.
+            dtype (`torch.dtype`, *optional*):
+                Torch dtype.

         Returns:
-            Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information
+            Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+                - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP)
+                - Cumulative sequence lengths (`cu_seqlens`) for the Qwen embeddings
         """
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
             prompt = [prompt]
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
-            batch_size = len(prompt)
+            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]

-        prompt = [prompt_clean(p) for p in prompt]
-
-        # Encode with Qwen2.5-VL
-        prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
-            prompt=prompt,
-            device=device,
-            num_videos_per_prompt=num_videos_per_prompt,
-            max_sequence_length=max_sequence_length,
-            dtype=dtype,
-        )
-
-        # Encode with CLIP
-        prompt_embeds_clip = self._encode_prompt_clip(
-            prompt=prompt,
-            device=device,
-            num_videos_per_prompt=num_videos_per_prompt,
-            dtype=dtype,
-        )
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
-            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-
-            negative_prompt = [prompt_clean(p) for p in negative_prompt]
-
-            negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen(
-                prompt=negative_prompt,
-                device=device,
-                num_videos_per_prompt=num_videos_per_prompt,
-                max_sequence_length=max_sequence_length,
-                dtype=dtype,
-            )
-            negative_prompt_embeds_clip = self._encode_prompt_clip(
-                prompt=negative_prompt,
-                device=device,
-                num_videos_per_prompt=num_videos_per_prompt,
-                dtype=dtype,
-            )
-        else:
-            negative_prompt_embeds_qwen = None
-            negative_prompt_embeds_clip = None
-            negative_cu_seqlens = None
+        if prompt_embeds is None:
+            prompt = [prompt_clean(p) for p in prompt]
+
+            prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
+                prompt=prompt,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                dtype=dtype,
+            )
+            prompt_embeds_clip = self._encode_prompt_clip(
+                prompt=prompt,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                dtype=dtype,
+            )
+        else:
+            prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds

         prompt_embeds_dict = {
             "text_embeds": prompt_embeds_qwen,
             "pooled_embed": prompt_embeds_clip,
         }
-        negative_prompt_embeds_dict = {
-            "text_embeds": negative_prompt_embeds_qwen,
-            "pooled_embed": negative_prompt_embeds_clip,
-        } if do_classifier_free_guidance else None

-        return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens
+        return prompt_embeds_dict, prompt_cu_seqlens
     def check_inputs(
         self,
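A minimal usage sketch of the refactored method above (assumptions: `pipe` is an already-instantiated Kandinsky5T2VPipeline with its text encoders loaded, and the prompt strings are illustrative):

# Usage sketch: with prompt encoding separated, the caller invokes
# encode_prompt once per prompt instead of once for the pos/neg pair.
# `pipe` is an assumed, already-loaded Kandinsky5T2VPipeline instance.
prompt_embeds_dict, prompt_cu_seqlens = pipe.encode_prompt(
    prompt="A red panda surfing a small wave at sunset",
    num_videos_per_prompt=1,
    max_sequence_length=512,
)
negative_prompt_embeds_dict, negative_cu_seqlens = pipe.encode_prompt(
    prompt="worst quality, low quality, deformed",
    num_videos_per_prompt=1,
    max_sequence_length=512,
)
# Each call returns ({"text_embeds": <Qwen2.5-VL>, "pooled_embed": <CLIP>}, cu_seqlens).

The caller in the hunk below applies exactly this pattern, making the second call only when classifier-free guidance is enabled.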
@@ -722,24 +662,43 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
             prompt = [prompt]
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]

         # 3. Encode input prompt
-        prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
+        prompt_embeds_dict, prompt_cu_seqlens = self.encode_prompt(
             prompt=prompt,
-            negative_prompt=negative_prompt,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
             num_videos_per_prompt=num_videos_per_prompt,
             prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
             dtype=dtype,
         )

+        negative_prompt_embeds_dict = None
+        negative_cu_seqlens = None
+
+        if self.do_classifier_free_guidance:
+            if negative_prompt is None:
+                negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
+            elif len(negative_prompt) != len(prompt):
+                raise ValueError(
+                    f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
+                )
+
+            negative_prompt_embeds_dict, negative_cu_seqlens = self.encode_prompt(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
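The `cu_seqlens` tensors returned alongside the embeddings follow the cumulative-sequence-lengths convention used by variable-length attention kernels (for example, flash-attention's varlen API): for a batch of N packed sequences, a tensor of N + 1 offsets into the packed token stream. A self-contained toy illustration, with made-up lengths:

import torch

# Toy illustration (made-up lengths): cumulative sequence lengths for a
# batch of three variable-length prompts packed into one token stream.
seq_lens = torch.tensor([7, 12, 5])
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_lens, dim=0).to(torch.int32)]
)
print(cu_seqlens)  # tensor([ 0,  7, 19, 24], dtype=torch.int32)
# Sequence i occupies packed positions cu_seqlens[i] : cu_seqlens[i + 1].

Offsets like these let an attention kernel skip padding tokens entirely instead of masking them, which is presumably why the pipeline threads `prompt_cu_seqlens` and `negative_cu_seqlens` through alongside the embedding dicts.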