mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
feat: make QwenImage family fully compilable again.
Co-authored-by: apolinario <joaopaulo.passos@gmail.com>
Co-authored-by: cbensimon <charles@huggingface.co>
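Note on the motivation (my reading of the hunks below, not text from the commit): the QwenImage transformer previously rebuilt its rotary position embeddings inside forward() from Python lists such as img_shapes and txt_seq_lens. torch.compile specializes on such list arguments, so every new resolution or prompt length forced a retrace. A toy sketch of that effect, with a made-up function:

    import torch

    # Toy illustration only; not diffusers code.
    @torch.compile
    def scale_by_num_sequences(x: torch.Tensor, seq_lens: list[int]) -> torch.Tensor:
        # len(seq_lens) is burned into the compiled graph as a constant.
        return x * len(seq_lens)

    x = torch.ones(4)
    scale_by_num_sequences(x, [32, 64])       # first call: compiles
    scale_by_num_sequences(x, [32, 64, 96])   # different list length: guard fails, recompiles

Precomputing the rotary embeddings in the pipeline and passing them in as tensors avoids that specialization, which is what the diff below does.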
@@ -557,6 +557,7 @@ class QwenImageTransformer2DModel(
         attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
         return_dict: bool = True,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`QwenTransformer2DModel`] forward method.
@@ -611,8 +612,8 @@ class QwenImageTransformer2DModel(
             if guidance is None
             else self.time_text_embed(timestep, guidance, hidden_states)
         )

-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        if image_rotary_emb is None:
+            image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
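With the guard above in place, callers can build the rotary embeddings once and pass them in. A self-contained sketch of the same pattern on a toy module (my own illustration; only the argument name and the `is None` guard mirror the diff):

    import torch
    from typing import Optional, Tuple

    class TinyRotaryModel(torch.nn.Module):
        # Toy stand-in for the transformer; not diffusers code.
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = torch.nn.Linear(dim, dim)

        def pos_embed(self, seq_len: int, device) -> Tuple[torch.Tensor, torch.Tensor]:
            # Stand-in for the real pos_embed(img_shapes, txt_seq_lens, ...).
            angles = torch.arange(seq_len, device=device, dtype=torch.float32)
            return angles.cos(), angles.sin()

        def forward(self, x, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
            if image_rotary_emb is None:  # same guard as in the diff
                image_rotary_emb = self.pos_embed(x.shape[1], x.device)
            cos, sin = image_rotary_emb
            return self.proj(x) + cos.mean() + sin.mean()  # placeholder use of the embeddings

    model = TinyRotaryModel()
    x = torch.randn(2, 16, 8)
    rotary = model.pos_embed(16, x.device)   # precompute once, as the pipeline now does
    out = torch.compile(model)(x, rotary)    # compiled forward() skips pos_embed entirely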
@@ -631,6 +631,10 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         negative_txt_seq_lens = (
             negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
         )
+        image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
+        neg_image_rotary_emb = None
+        if do_true_cfg:
+            neg_image_rotary_emb = self.transformer.pos_embed(img_shapes, negative_txt_seq_lens, device=latents.device)

         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
@@ -649,8 +653,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                     guidance=guidance,
                     encoder_hidden_states_mask=prompt_embeds_mask,
                     encoder_hidden_states=prompt_embeds,
-                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
+                    image_rotary_emb=image_rotary_emb,
                     attention_kwargs=self.attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -663,8 +666,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                     guidance=guidance,
                     encoder_hidden_states_mask=negative_prompt_embeds_mask,
                     encoder_hidden_states=negative_prompt_embeds,
-                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_txt_seq_lens,
+                    image_rotary_emb=neg_image_rotary_emb,
                     attention_kwargs=self.attention_kwargs,
                     return_dict=False,
                 )[0]
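Taken together, the hunks let the transformer be compiled without retracing on each call. An end-to-end usage sketch, assuming the commit delivers what its title states; the model ID, dtype, prompt, and step count are illustrative, while QwenImagePipeline, torch.compile, and torch._logging.set_logs are existing APIs:

    import torch
    from diffusers import QwenImagePipeline

    torch._logging.set_logs(recompiles=True)  # surface any recompilation during sampling

    # Illustrative settings; adjust model ID, dtype, and device to your setup.
    pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
    pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)

    image = pipe("a watercolor fox in a snowy forest", num_inference_steps=28).images[0]
    image.save("qwen_image.png")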