diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
index 6ebedd04e8..bdf7e41df9 100644
--- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -152,6 +152,16 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             tokenizer_2=tokenizer_2,
             scheduler=scheduler,
         )
+
+        self.prompt_template = "\n".join(["<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
+            "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
+            "Describe the location of the video, main characters or objects and their action.",
+            "Describe the dynamism of the video and presented actions.",
+            "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
+            "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
+            "Pay attention to the order of key actions shown in the scene.<|im_end|>",
+            "<|im_start|>user\n{}<|im_end|>"])
+        self.prompt_template_encode_start_idx = 129
 
         self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
         self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
@@ -276,29 +286,14 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         """
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
 
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        prompt = [prompt_clean(p) for p in prompt]
-        # Kandinsky specific prompt template for detailed video description
-        prompt_template = "\n".join([
-            "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
-            "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
-            "Describe the location of the video, main characters or objects and their action.",
-            "Describe the dynamism of the video and presented actions.",
-            "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
-            "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
-            "Pay attention to the order of key actions shown in the scene.<|im_end|>",
-            "<|im_start|>user\n{}<|im_end|>",
-        ])
-        crop_start = 129  # Position to start cropping from (system prompt tokens)
-
-        full_texts = [prompt_template.format(p) for p in prompt]
+        full_texts = [self.prompt_template.format(p) for p in prompt]
 
         inputs = self.tokenizer(
             text=full_texts,
             images=None,
             videos=None,
-            max_length=max_sequence_length + crop_start,
+            max_length=max_sequence_length + self.prompt_template_encode_start_idx,
             truncation=True,
             return_tensors="pt",
             padding=True,
@@ -308,11 +303,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             input_ids=inputs["input_ids"],
             return_dict=True,
             output_hidden_states=True,
-        )["hidden_states"][-1][:, crop_start:]
+        )["hidden_states"][-1][:, self.prompt_template_encode_start_idx:]
 
         batch_size = len(prompt)
 
-        attention_mask = inputs["attention_mask"][:, crop_start:]
+        attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:]
         cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
         embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0)
@@ -343,8 +338,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         """
         device = device or self._execution_device
         dtype = dtype or self.text_encoder_2.dtype
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        prompt = [prompt_clean(p) for p in prompt]
 
         inputs = self.tokenizer_2(
             prompt,
@@ -357,7 +350,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
 
         pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
 
-        # duplicate for each generation per prompt
         batch_size = len(prompt)
         pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1)
         pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1)
@@ -421,6 +413,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
+            prompt = [prompt_clean(p) for p in prompt]
+
             prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
                 prompt=prompt,
                 device=device,
@@ -452,6 +446,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                     " the batch size of `prompt`."
                 )
+
+            negative_prompt = [prompt_clean(p) for p in negative_prompt]
 
             negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen(
                 prompt=negative_prompt,
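
With the template and crop index hoisted into `__init__`, both become plain instance attributes that can be inspected after loading. Below is a minimal sketch of how the two are meant to relate: `prompt_template_encode_start_idx` is the tokenized length of the fixed system prefix that `_encode_prompt_qwen` slices off the hidden states and attention mask. It assumes `Kandinsky5T2VPipeline` is exported from the top-level `diffusers` namespace and uses a hypothetical checkpoint id; whether the 129-token boundary also counts the `<|im_start|>user\n` header tokens is an assumption, so the comparison is printed rather than asserted.

```python
import torch
from diffusers import Kandinsky5T2VPipeline

# Hypothetical checkpoint id, for illustration only.
pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.bfloat16
)

# Everything before the user turn is the fixed system prefix that
# _encode_prompt_qwen crops from the hidden states and attention mask.
system_prefix = pipe.prompt_template.split("<|im_start|>user")[0]
prefix_ids = pipe.tokenizer(text=system_prefix, return_tensors="pt")["input_ids"]

# The stock template hardcodes 129; treat this as a sanity print, since the
# exact crop boundary relative to the user header is an assumption here.
print(prefix_ids.shape[1], pipe.prompt_template_encode_start_idx)
```

A custom template can be assigned the same way, but `prompt_template_encode_start_idx` must then be recomputed for the new system prefix, or the slice in `_encode_prompt_qwen` will cut into (or leave behind) user-prompt tokens.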
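The `cu_seqlens` tensor built right after the crop follows the cumulative-offsets convention used by varlen attention kernels: entry `i` marks where sequence `i` starts once the unpadded sequences are packed back to back, so `[cu[i], cu[i+1])` indexes sequence `i`. A self-contained sketch with a made-up padding mask:

```python
import torch
import torch.nn.functional as F

# Made-up padding mask for three prompts of lengths 4, 2 and 3.
attention_mask = torch.tensor([
    [1, 1, 1, 1],
    [1, 1, 0, 0],
    [1, 1, 1, 0],
])

# Same two lines as in the pipeline: per-sequence lengths -> running offsets,
# with a leading 0 so offsets start at the beginning of the packed batch.
cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
print(cu_seqlens)  # tensor([0, 4, 6, 9], dtype=torch.int32)
```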