mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

conditioned CFG

Author: davidb
Committed by: DavidBert
Date: 2025-10-10 09:46:20 +00:00
Parent: b327b36ad9
Commit: 60d918d79b


@@ -245,13 +245,13 @@ class PhotonPipeline(
     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents"]
-    _optional_components = []
+    _optional_components = ["vae"]

     def __init__(
         self,
         transformer: PhotonTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        text_encoder: Union[T5GemmaEncoder],
+        text_encoder: T5GemmaEncoder,
         tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
         vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
         default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
@@ -330,6 +330,11 @@ class PhotonPipeline(
"""Compatibility property that returns spatial compression ratio."""
return getattr(self.vae, "spatial_compression_ratio", 8)
@property
def do_classifier_free_guidance(self):
"""Check if classifier-free guidance is enabled based on guidance scale."""
return self._guidance_scale > 1.0
def prepare_latents(
self,
batch_size: int,
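
Note for readers: the new do_classifier_free_guidance property uses guidance_scale > 1.0 as the cutoff because, at a scale of exactly 1.0, the CFG combination reduces to the conditional prediction and the unconditional pass contributes nothing. A minimal sketch of that arithmetic (illustrative only, not part of the diff):

def apply_cfg(noise_uncond, noise_cond, s):
    # Standard CFG blend; with s == 1.0 this returns noise_cond exactly,
    # so the unconditional branch would be wasted compute.
    return noise_uncond + s * (noise_cond - noise_uncond)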
@@ -353,49 +358,67 @@ class PhotonPipeline(
         latents = latents.to(device)

         return latents

-    def encode_prompt(self, prompt: Union[str, List[str]], device: torch.device):
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: str = "",
+    ):
         """Encode text prompt using standard text encoder and tokenizer."""
         if isinstance(prompt, str):
             prompt = [prompt]

-        return self._encode_prompt_standard(prompt, device)
-
-    def _encode_prompt_standard(self, prompt: List[str], device: torch.device):
-        """Encode prompt using standard text encoder and tokenizer with batch processing."""
-        # Clean text using modular preprocessor
-        cleaned_prompts = [self.text_preprocessor.clean_text(text) for text in prompt]
-        cleaned_uncond_prompts = [self.text_preprocessor.clean_text("") for _ in prompt]
-        all_prompts = cleaned_prompts + cleaned_uncond_prompts
+        return self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
+
+    def _tokenize_prompts(self, prompts: List[str], device: torch.device):
+        """Tokenize and clean prompts."""
+        cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
         tokens = self.tokenizer(
-            all_prompts,
+            cleaned,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
             truncation=True,
             return_attention_mask=True,
             return_tensors="pt",
         )
-        return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
+        input_ids = tokens["input_ids"].to(device)
+        attention_mask = tokens["attention_mask"].bool().to(device)
+        return input_ids, attention_mask
+
+    def _encode_prompt_standard(
+        self,
+        prompt: List[str],
+        device: torch.device,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: str = "",
+    ):
+        """Encode prompt using standard text encoder and tokenizer with batch processing."""
+        batch_size = len(prompt)
+
+        if do_classifier_free_guidance:
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * batch_size
+            prompts_to_encode = negative_prompt + prompt
+        else:
+            prompts_to_encode = prompt
+
+        input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)

         with torch.no_grad():
-            emb = self.text_encoder(
+            embeddings = self.text_encoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_hidden_states=True,
-            )
+            )["last_hidden_state"]

-        all_embeddings = emb["last_hidden_state"]
-
-        # Split back into conditional and unconditional
-        batch_size = len(prompt)
-        text_embeddings = all_embeddings[:batch_size]
-        uncond_text_embeddings = all_embeddings[batch_size:]
-        cross_attn_mask = attention_mask[:batch_size]
-        uncond_cross_attn_mask = attention_mask[batch_size:]
+        if do_classifier_free_guidance:
+            uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
+            uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
+        else:
+            text_embeddings = embeddings
+            cross_attn_mask = attention_mask
+            uncond_text_embeddings = None
+            uncond_cross_attn_mask = None

         return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
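
For anyone following the refactor above, the batching convention is that negative (unconditional) prompts come first and positive prompts second, so one encoder forward pass can be split back into the two halves. A small self-contained sketch with placeholder shapes (not the real encoder dimensions):

import torch

batch_size, seq_len, hidden = 2, 77, 1024  # placeholder sizes
embeddings = torch.randn(2 * batch_size, seq_len, hidden)  # [neg_0, neg_1, pos_0, pos_1]
attention_mask = torch.ones(2 * batch_size, seq_len, dtype=torch.bool)

# Same split performed when CFG is enabled: unconditional half first, conditional half second.
uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
print(text_embeddings.shape, uncond_text_embeddings.shape)  # torch.Size([2, 77, 1024]) each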
@@ -534,9 +557,11 @@ class PhotonPipeline(
         device = self._execution_device

+        self._guidance_scale = guidance_scale
+
         # 2. Encode input prompt
         text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
-            prompt, device
+            prompt, device, do_classifier_free_guidance=self.do_classifier_free_guidance
         )

         # 3. Prepare timesteps
@@ -572,17 +597,22 @@ class PhotonPipeline(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                # Duplicate latents for CFG
-                latents_in = torch.cat([latents, latents], dim=0)
-
-                # Cross-attention batch (uncond, cond)
-                ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
-                ca_mask = None
-                if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
-                    ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
-
-                # Normalize timestep for the transformer
-                t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+                # Duplicate latents if using classifier-free guidance
+                if self.do_classifier_free_guidance:
+                    latents_in = torch.cat([latents, latents], dim=0)
+
+                    # Cross-attention batch (uncond, cond)
+                    ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+                    ca_mask = None
+                    if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+                        ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+
+                    # Normalize timestep for the transformer
+                    t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+                else:
+                    latents_in = latents
+                    ca_embed = text_embeddings
+                    ca_mask = cross_attn_mask
+
+                    # Normalize timestep for the transformer
+                    t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)

                 # Process inputs for transformer
                 img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed)
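
One subtlety in the branch above: the normalized timestep must match the transformer batch, so it is repeated twice when the latents are duplicated for CFG and left as a single element otherwise. A tiny sketch of that shape handling (1000 is a placeholder for scheduler.config.num_train_timesteps):

import torch

num_train_timesteps = 1000  # placeholder value
t = torch.tensor(800)
do_cfg = True

t_cont = (t.float() / num_train_timesteps).view(1)
if do_cfg:
    t_cont = t_cont.repeat(2)  # one timestep per CFG branch
print(t_cont)  # tensor([0.8000, 0.8000]) with CFG, tensor([0.8000]) without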
@@ -597,11 +627,12 @@ class PhotonPipeline(
                 )

                 # Convert back to image format
-                noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)
+                noise_pred = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)

                 # Apply CFG
-                noise_uncond, noise_text = noise_both.chunk(2, dim=0)
-                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+                if self.do_classifier_free_guidance:
+                    noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
+                    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

                 # Compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
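
End-to-end, the change makes the unconditional branch opt-in via guidance_scale. A hedged usage sketch (the checkpoint path is a placeholder and this assumes PhotonPipeline is exported from the top-level diffusers namespace):

from diffusers import PhotonPipeline  # assumes top-level export of the pipeline

pipe = PhotonPipeline.from_pretrained("path/to/photon-checkpoint")  # placeholder path

# guidance_scale <= 1.0: do_classifier_free_guidance is False, single-branch denoising per step.
image_fast = pipe("a red bicycle", guidance_scale=1.0).images[0]

# guidance_scale > 1.0: latents, embeddings, and masks are duplicated and CFG is applied.
image_guided = pipe("a red bicycle", guidance_scale=5.0).images[0]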