From c951adef45cf0fbc6b06ecb1caa586fdb7da30bd Mon Sep 17 00:00:00 2001
From: David Bertoin
Date: Mon, 13 Oct 2025 12:48:53 +0000
Subject: [PATCH] move xattention conditioning computation out of the denoising loop

---
 .../pipelines/photon/pipeline_photon.py | 31 ++++++++++---------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py
index ea9844fee2..9e47f8ebc0 100644
--- a/src/diffusers/pipelines/photon/pipeline_photon.py
+++ b/src/diffusers/pipelines/photon/pipeline_photon.py
@@ -265,6 +265,7 @@ class PhotonPipeline(
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
         self.text_preprocessor = TextPreprocessor()
+        self.default_sample_size = default_sample_size
 
         self.register_modules(
             transformer=transformer,
@@ -274,7 +275,7 @@ class PhotonPipeline(
             vae=vae,
         )
 
-        self.register_to_config(default_sample_size=default_sample_size)
+        self.register_to_config(default_sample_size=self.default_sample_size)
 
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
@@ -539,13 +540,12 @@ class PhotonPipeline(
                 generated images.
         """
 
-        # 0. Default height and width from config
-        default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
-        height = height or default_sample_size
-        width = width or default_sample_size
+        # 0. Set height and width
+        height = height or self.default_sample_size
+        width = width or self.default_sample_size
 
         if use_resolution_binning:
-            if default_sample_size <= 256:
+            if self.default_sample_size <= 256:
                 aspect_ratio_bin = ASPECT_RATIO_256_BIN
             else:
                 aspect_ratio_bin = ASPECT_RATIO_512_BIN
@@ -616,7 +616,17 @@ class PhotonPipeline(
         if accepts_eta:
             extra_step_kwargs["eta"] = 0.0
 
-        # 6. Denoising loop
+        # 6. Prepare cross-attention embeddings and masks
+        if self.do_classifier_free_guidance:
+            ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+            ca_mask = None
+            if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+                ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+        else:
+            ca_embed = text_embeddings
+            ca_mask = cross_attn_mask
+
+        # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -624,17 +634,10 @@ class PhotonPipeline(
                 # Duplicate latents if using classifier-free guidance
                 if self.do_classifier_free_guidance:
                     latents_in = torch.cat([latents, latents], dim=0)
-                    # Cross-attention batch (uncond, cond)
-                    ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
-                    ca_mask = None
-                    if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
-                        ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
                     # Normalize timestep for the transformer
                     t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
                 else:
                     latents_in = latents
-                    ca_embed = text_embeddings
-                    ca_mask = cross_attn_mask
                     # Normalize timestep for the transformer
                     t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
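
Note: the patch hoists the classifier-free-guidance conditioning out of the per-step loop because ca_embed and ca_mask do not depend on the timestep; only the latent duplication still has to happen each step. The sketch below illustrates the same hoisting pattern in isolation. It is not the PhotonPipeline API: run_denoising_loop and denoise_step are hypothetical names used only to show where the torch.cat calls move.

    from typing import Callable, Optional

    import torch


    def run_denoising_loop(
        latents: torch.Tensor,
        timesteps: torch.Tensor,
        text_embeddings: torch.Tensor,
        uncond_text_embeddings: torch.Tensor,
        cross_attn_mask: Optional[torch.Tensor],
        uncond_cross_attn_mask: Optional[torch.Tensor],
        do_classifier_free_guidance: bool,
        denoise_step: Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
    ) -> torch.Tensor:
        # Prepare cross-attention embeddings and masks once: under classifier-free
        # guidance they are identical at every timestep, so the concatenations are
        # done before the loop instead of being repeated per step.
        if do_classifier_free_guidance:
            ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
            ca_mask = None
            if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
                ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
        else:
            ca_embed = text_embeddings
            ca_mask = cross_attn_mask

        for t in timesteps:
            # Only the latent duplication remains per-step; the conditioning
            # tensors are reused unchanged across all steps.
            latents_in = torch.cat([latents, latents], dim=0) if do_classifier_free_guidance else latents
            latents = denoise_step(latents_in, t, ca_embed, ca_mask)
        return latents

The generated images are unchanged by this refactor; it only removes redundant per-step concatenations.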