Move cross-attention conditioning computation out of the denoising loop
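The conditioning tensors passed to cross-attention do not depend on the timestep, so assembling them once before the loop removes a redundant `torch.cat` on the embeddings (and masks) from every denoising step. A minimal sketch of the pattern, with illustrative names rather than the pipeline's real signature:

```python
import torch

def denoise(latents, text_embeddings, uncond_text_embeddings, timesteps, do_cfg=True):
    # Hoisted out of the loop: the cross-attention conditioning does not
    # depend on t, so it is assembled once instead of once per step.
    if do_cfg:
        ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
    else:
        ca_embed = text_embeddings

    for t in timesteps:
        latents_in = torch.cat([latents, latents], dim=0) if do_cfg else latents
        # ... transformer forward with (latents_in, t, ca_embed), then scheduler step ...
    return latents
```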
@@ -265,6 +265,7 @@ class PhotonPipeline(
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
         self.text_preprocessor = TextPreprocessor()
+        self.default_sample_size = default_sample_size
 
         self.register_modules(
             transformer=transformer,
@@ -274,7 +275,7 @@ class PhotonPipeline(
             vae=vae,
         )
 
-        self.register_to_config(default_sample_size=default_sample_size)
+        self.register_to_config(default_sample_size=self.default_sample_size)
 
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -539,13 +540,12 @@ class PhotonPipeline(
             generated images.
         """
 
-        # 0. Default height and width from config
-        default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
-        height = height or default_sample_size
-        width = width or default_sample_size
+        # 0. Set height and width
+        height = height or self.default_sample_size
+        width = width or self.default_sample_size
 
         if use_resolution_binning:
-            if default_sample_size <= 256:
+            if self.default_sample_size <= 256:
                 aspect_ratio_bin = ASPECT_RATIO_256_BIN
             else:
                 aspect_ratio_bin = ASPECT_RATIO_512_BIN
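For context (unchanged by this commit), `use_resolution_binning` snaps the requested size to the closest supported aspect-ratio bin before generation and resizes back afterwards. A hedged sketch, assuming diffusers' `PixArtImageProcessor.classify_height_width_bin` helper and the `ASPECT_RATIO_512_BIN` table:

```python
from diffusers.image_processor import PixArtImageProcessor
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_512_BIN

# Snap an arbitrary request to the nearest 512-class bin.
height, width = PixArtImageProcessor.classify_height_width_bin(520, 496, ratios=ASPECT_RATIO_512_BIN)
print(height, width)  # e.g. 512 512 for a near-square request
```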
@@ -616,7 +616,17 @@ class PhotonPipeline(
         if accepts_eta:
             extra_step_kwargs["eta"] = 0.0
 
-        # 6. Denoising loop
+        # 6. Prepare cross-attention embeddings and masks
+        if self.do_classifier_free_guidance:
+            ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+            ca_mask = None
+            if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+                ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+        else:
+            ca_embed = text_embeddings
+            ca_mask = cross_attn_mask
+
+        # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
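Note that `ca_embed` is batched as (uncond, cond); downstream, the doubled prediction is usually split in the same order and blended. A sketch of that standard consumption pattern (`apply_cfg` and `guidance_scale` are illustrative names, not from this diff):

```python
import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Split the doubled batch into its (uncond, cond) halves, matching the
    # order used when building ca_embed, then blend the two predictions.
    noise_uncond, noise_cond = noise_pred.chunk(2, dim=0)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```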
@@ -624,17 +634,10 @@ class PhotonPipeline(
                 # Duplicate latents if using classifier-free guidance
                 if self.do_classifier_free_guidance:
                     latents_in = torch.cat([latents, latents], dim=0)
-                    # Cross-attention batch (uncond, cond)
-                    ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
-                    ca_mask = None
-                    if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
-                        ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
                     # Normalize timestep for the transformer
                     t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
                 else:
                     latents_in = latents
-                    ca_embed = text_embeddings
-                    ca_mask = cross_attn_mask
                     # Normalize timestep for the transformer
                     t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
 
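For context, `t_cont` is just the scheduler timestep rescaled to [0, 1], repeated to match the doubled batch under guidance. A quick worked example, assuming `num_train_timesteps` is 1000:

```python
import torch

t = torch.tensor(981)  # raw scheduler timestep
t_cont = (t.float() / 1000).view(1).repeat(2)  # assumes num_train_timesteps == 1000
print(t_cont)  # tensor([0.9810, 0.9810]), one entry per guidance batch half
```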