mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
conditioned CFG
@@ -245,13 +245,13 @@ class PhotonPipeline(
     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents"]
-    _optional_components = []
+    _optional_components = ["vae"]
 
     def __init__(
         self,
         transformer: PhotonTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        text_encoder: Union[T5GemmaEncoder],
+        text_encoder: T5GemmaEncoder,
         tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
         vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
         default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
@@ -330,6 +330,11 @@ class PhotonPipeline(
         """Compatibility property that returns spatial compression ratio."""
         return getattr(self.vae, "spatial_compression_ratio", 8)
 
+    @property
+    def do_classifier_free_guidance(self):
+        """Check if classifier-free guidance is enabled based on guidance scale."""
+        return self._guidance_scale > 1.0
+
     def prepare_latents(
         self,
         batch_size: int,
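Note: gating CFG on guidance_scale > 1.0 follows the usual convention, since at a scale of exactly 1 the guided prediction collapses to the conditional one and the unconditional pass adds nothing. A minimal standalone sketch of that reasoning (illustrative only, not part of this commit; the helper name combine_cfg is hypothetical):

    import torch

    def combine_cfg(noise_uncond, noise_text, guidance_scale):
        # Standard classifier-free guidance combination.
        return noise_uncond + guidance_scale * (noise_text - noise_uncond)

    uncond = torch.randn(2, 4, 32, 32)
    cond = torch.randn(2, 4, 32, 32)
    # At guidance_scale == 1.0 the result is (numerically) the conditional prediction,
    # so skipping the unconditional branch loses nothing.
    assert torch.allclose(combine_cfg(uncond, cond, 1.0), cond, atol=1e-5)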
@@ -353,49 +358,67 @@ class PhotonPipeline(
         latents = latents.to(device)
         return latents
 
-    def encode_prompt(self, prompt: Union[str, List[str]], device: torch.device):
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: str = "",
+    ):
         """Encode text prompt using standard text encoder and tokenizer."""
         if isinstance(prompt, str):
             prompt = [prompt]
 
-        return self._encode_prompt_standard(prompt, device)
+        return self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
 
-    def _encode_prompt_standard(self, prompt: List[str], device: torch.device):
-        """Encode prompt using standard text encoder and tokenizer with batch processing."""
-        # Clean text using modular preprocessor
-        cleaned_prompts = [self.text_preprocessor.clean_text(text) for text in prompt]
-        cleaned_uncond_prompts = [self.text_preprocessor.clean_text("") for _ in prompt]
-
-        all_prompts = cleaned_prompts + cleaned_uncond_prompts
+    def _tokenize_prompts(self, prompts: List[str], device: torch.device):
+        """Tokenize and clean prompts."""
+        cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
         tokens = self.tokenizer(
-            all_prompts,
+            cleaned,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
            truncation=True,
             return_attention_mask=True,
             return_tensors="pt",
         )
+        return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
 
-        input_ids = tokens["input_ids"].to(device)
-        attention_mask = tokens["attention_mask"].bool().to(device)
+    def _encode_prompt_standard(
+        self,
+        prompt: List[str],
+        device: torch.device,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: str = "",
+    ):
+        """Encode prompt using standard text encoder and tokenizer with batch processing."""
+        batch_size = len(prompt)
+
+        if do_classifier_free_guidance:
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * batch_size
+
+            prompts_to_encode = negative_prompt + prompt
+        else:
+            prompts_to_encode = prompt
+
+        input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)
 
         with torch.no_grad():
-            emb = self.text_encoder(
+            embeddings = self.text_encoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_hidden_states=True,
-            )
-
-        all_embeddings = emb["last_hidden_state"]
-
-        # Split back into conditional and unconditional
-        batch_size = len(prompt)
-        text_embeddings = all_embeddings[:batch_size]
-        uncond_text_embeddings = all_embeddings[batch_size:]
-
-        cross_attn_mask = attention_mask[:batch_size]
-        uncond_cross_attn_mask = attention_mask[batch_size:]
+            )["last_hidden_state"]
+
+        if do_classifier_free_guidance:
+            uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
+            uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
+        else:
+            text_embeddings = embeddings
+            cross_attn_mask = attention_mask
+            uncond_text_embeddings = None
+            uncond_cross_attn_mask = None
 
         return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
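Note: the refactored encoder batches negative prompts first and positive prompts second, so split(batch_size, dim=0) yields (unconditional, conditional) in that order; the denoising loop concatenates them back the same way. A minimal standalone sketch of that ordering (illustrative only; dummy tensors stand in for the real text encoder output):

    import torch

    batch_size = 2
    prompt = ["a cat", "a dog"]
    negative_prompt = [""] * batch_size

    # Negative prompts go first, matching the refactored _encode_prompt_standard.
    prompts_to_encode = negative_prompt + prompt
    embeddings = torch.randn(len(prompts_to_encode), 77, 8)  # (2 * batch, seq_len, dim)

    # split(batch_size, dim=0) therefore returns (unconditional, conditional) in that order.
    uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
    assert uncond_text_embeddings.shape[0] == text_embeddings.shape[0] == batch_size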
@@ -534,9 +557,11 @@ class PhotonPipeline(
 
         device = self._execution_device
 
+        self._guidance_scale = guidance_scale
+
         # 2. Encode input prompt
         text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
-            prompt, device
+            prompt, device, do_classifier_free_guidance=self.do_classifier_free_guidance
         )
 
         # 3. Prepare timesteps
@@ -572,17 +597,22 @@ class PhotonPipeline(
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                # Duplicate latents for CFG
-                latents_in = torch.cat([latents, latents], dim=0)
-
-                # Cross-attention batch (uncond, cond)
-                ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
-                ca_mask = None
-                if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
-                    ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
-
-                # Normalize timestep for the transformer
-                t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+                # Duplicate latents if using classifier-free guidance
+                if self.do_classifier_free_guidance:
+                    latents_in = torch.cat([latents, latents], dim=0)
+                    # Cross-attention batch (uncond, cond)
+                    ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+                    ca_mask = None
+                    if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+                        ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+                    # Normalize timestep for the transformer
+                    t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+                else:
+                    latents_in = latents
+                    ca_embed = text_embeddings
+                    ca_mask = cross_attn_mask
+                    # Normalize timestep for the transformer
+                    t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
 
                 # Process inputs for transformer
                 img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed)
@@ -597,11 +627,12 @@ class PhotonPipeline(
                 )
 
                 # Convert back to image format
-                noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)
+                noise_pred = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)
 
                 # Apply CFG
-                noise_uncond, noise_text = noise_both.chunk(2, dim=0)
-                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+                if self.do_classifier_free_guidance:
+                    noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
+                    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
 
                 # Compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
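Note: with the conditional branch, the transformer only sees a doubled batch while CFG is active; otherwise the batch stays at its original size and the chunk/combination step is skipped. A minimal standalone shape sketch of that bookkeeping (illustrative only; do_cfg and the random tensor stand in for the property and the transformer output):

    import torch

    guidance_scale = 4.5
    do_cfg = guidance_scale > 1.0
    latents = torch.randn(1, 4, 32, 32)

    # Duplicate the latent batch only when CFG is enabled.
    latents_in = torch.cat([latents, latents], dim=0) if do_cfg else latents
    noise_pred = torch.randn_like(latents_in)  # stands in for the transformer output

    if do_cfg:
        noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
        noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

    # The final prediction always matches the original latent batch shape.
    assert noise_pred.shape == latents.shape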