diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 846add8906..a0ef975d32 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -557,6 +557,7 @@ class QwenImageTransformer2DModel(
         attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
         return_dict: bool = True,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`QwenTransformer2DModel`] forward method.
@@ -611,8 +612,8 @@ class QwenImageTransformer2DModel(
             if guidance is None
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
-
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        if image_rotary_emb is None:
+            image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
index 8a2ee7b88e..8c5698bd60 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -631,6 +631,10 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         negative_txt_seq_lens = (
             negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
         )
+        image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
+        neg_image_rotary_emb = None
+        if do_true_cfg:
+            neg_image_rotary_emb = self.transformer.pos_embed(img_shapes, negative_txt_seq_lens, device=latents.device)

         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
@@ -649,8 +653,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                         guidance=guidance,
                         encoder_hidden_states_mask=prompt_embeds_mask,
                         encoder_hidden_states=prompt_embeds,
-                        img_shapes=img_shapes,
-                        txt_seq_lens=txt_seq_lens,
+                        image_rotary_emb=image_rotary_emb,
                         attention_kwargs=self.attention_kwargs,
                         return_dict=False,
                     )[0]
@@ -663,8 +666,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                             guidance=guidance,
                             encoder_hidden_states_mask=negative_prompt_embeds_mask,
                             encoder_hidden_states=negative_prompt_embeds,
-                            img_shapes=img_shapes,
-                            txt_seq_lens=negative_txt_seq_lens,
+                            image_rotary_emb=neg_image_rotary_emb,
                             attention_kwargs=self.attention_kwargs,
                             return_dict=False,
                         )[0]
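
For context, the patch moves the rotary position embedding (RoPE) computation out of the per-step transformer forward pass: the pipeline computes it once before the denoising loop and passes it in through the new image_rotary_emb argument, while the old pos_embed(img_shapes, txt_seq_lens, ...) path is kept as a fallback when the argument is omitted. A minimal sketch of the resulting call pattern follows; it is not a standalone script, since transformer, latents, prompt_embeds, prompt_embeds_mask, img_shapes, txt_seq_lens, and timesteps stand for objects the pipeline already builds, and only pos_embed and image_rotary_emb come from the patch itself.

# Sketch only: the surrounding variables are assumed to be the ones QwenImagePipeline
# already constructs; nothing here is new API beyond the image_rotary_emb argument above.
image_rotary_emb = transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

for t in timesteps:
    timestep = t.expand(latents.shape[0]).to(latents.dtype)
    noise_pred = transformer(
        hidden_states=latents,
        timestep=timestep / 1000,
        encoder_hidden_states=prompt_embeds,
        encoder_hidden_states_mask=prompt_embeds_mask,
        image_rotary_emb=image_rotary_emb,  # reused every step; pos_embed is not re-run
        return_dict=False,
    )[0]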