mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

move cross-attention conditioning computation out of the denoising loop

David Bertoin
2025-10-13 12:48:53 +00:00
committed by DavidBert
parent a74e0b726a
commit c951adef45
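
In short, the commit hoists the classifier-free-guidance (CFG) conditioning out of the per-step loop: the unconditional and conditional text embeddings (and their attention masks) are identical at every denoising step, so they can be concatenated once before the loop instead of re-concatenated at every iteration. A minimal sketch of that pattern in plain PyTorch follows; it is not the PhotonPipeline implementation, and the function name `denoise`, the `model` callable, and the placeholder latent update are illustrative assumptions.

import torch


def denoise(latents, text_embeddings, uncond_text_embeddings, timesteps, model, guidance_scale=4.5):
    """Denoising loop with the CFG conditioning prepared once, up front."""
    do_cfg = guidance_scale > 1.0

    # Prepare cross-attention embeddings once: they do not change between steps.
    if do_cfg:
        ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
    else:
        ca_embed = text_embeddings

    for t in timesteps:
        # Only the latent batch and the timestep vary inside the loop.
        latents_in = torch.cat([latents, latents], dim=0) if do_cfg else latents
        noise_pred = model(latents_in, t, ca_embed)

        if do_cfg:
            noise_uncond, noise_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # Placeholder update; a real pipeline calls scheduler.step(...) here.
        latents = latents - noise_pred / len(timesteps)

    return latents

Because the concatenated tensors are identical at every step, moving the torch.cat calls out of the loop does not change the output; it only removes redundant per-step work.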


@@ -265,6 +265,7 @@ class PhotonPipeline(
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
         self.text_preprocessor = TextPreprocessor()
+        self.default_sample_size = default_sample_size
         self.register_modules(
             transformer=transformer,
@@ -274,7 +275,7 @@ class PhotonPipeline(
             vae=vae,
         )
-        self.register_to_config(default_sample_size=default_sample_size)
+        self.register_to_config(default_sample_size=self.default_sample_size)
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -539,13 +540,12 @@ class PhotonPipeline(
                 generated images.
         """
-        # 0. Default height and width from config
-        default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
-        height = height or default_sample_size
-        width = width or default_sample_size
+        # 0. Set height and width
+        height = height or self.default_sample_size
+        width = width or self.default_sample_size
         if use_resolution_binning:
-            if default_sample_size <= 256:
+            if self.default_sample_size <= 256:
                 aspect_ratio_bin = ASPECT_RATIO_256_BIN
             else:
                 aspect_ratio_bin = ASPECT_RATIO_512_BIN
@@ -616,7 +616,17 @@ class PhotonPipeline(
         if accepts_eta:
             extra_step_kwargs["eta"] = 0.0
-        # 6. Denoising loop
+        # 6. Prepare cross-attention embeddings and masks
+        if self.do_classifier_free_guidance:
+            ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+            ca_mask = None
+            if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+                ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+        else:
+            ca_embed = text_embeddings
+            ca_mask = cross_attn_mask
+        # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -624,17 +634,10 @@ class PhotonPipeline(
                 # Duplicate latents if using classifier-free guidance
                 if self.do_classifier_free_guidance:
                     latents_in = torch.cat([latents, latents], dim=0)
-                    # Cross-attention batch (uncond, cond)
-                    ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
-                    ca_mask = None
-                    if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
-                        ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
                     # Normalize timestep for the transformer
                     t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
                 else:
                     latents_in = latents
-                    ca_embed = text_embeddings
-                    ca_mask = cross_attn_mask
                     # Normalize timestep for the transformer
                     t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
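
With default_sample_size now stored as an attribute in __init__, a caller who omits height and width gets the pipeline's default resolution, while explicit values still take precedence. A hedged usage sketch follows; the checkpoint path is a placeholder, and the prompt/argument names assume the standard diffusers text-to-image calling convention rather than anything confirmed by this diff.

import torch
from diffusers import PhotonPipeline  # assumes PhotonPipeline is exported on this branch

# "path/to/photon-checkpoint" is a placeholder, not a real model id.
pipe = PhotonPipeline.from_pretrained("path/to/photon-checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# height/width fall back to pipe.default_sample_size when omitted...
image = pipe("a lighthouse at dusk, photorealistic").images[0]

# ...and explicit values still override the default.
image_large = pipe("a lighthouse at dusk, photorealistic", height=1024, width=1024).images[0]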